[Mlir-commits] [mlir] 108abd2 - [mlir] Add a new MutableOperandRange class for adding/removing operands
River Riddle
llvmlistbot at llvm.org
Wed Apr 29 16:49:42 PDT 2020
Author: River Riddle
Date: 2020-04-29T16:48:14-07:00
New Revision: 108abd2f2eaf9a197f92fcc672777f8561d330dd
URL: https://github.com/llvm/llvm-project/commit/108abd2f2eaf9a197f92fcc672777f8561d330dd
DIFF: https://github.com/llvm/llvm-project/commit/108abd2f2eaf9a197f92fcc672777f8561d330dd.diff
LOG: [mlir] Add a new MutableOperandRange class for adding/removing operands
This class allows for mutating an operand range in-place, and provides a vector-like API for adding, erasing, and setting operands. ODS now uses this class to generate mutable wrappers for named operands, with the name `MutableOperandRange <operand-name>Mutable()`
Differential Revision: https://reviews.llvm.org/D78892
Added:
Modified:
mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/IR/UseDefLists.h
mlir/lib/IR/Attributes.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/IR/OperationSupport.cpp
mlir/test/mlir-tblgen/op-decl.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/unittests/IR/OperationSupportTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index e42f50b7ac32..43628b8ce506 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -296,6 +296,10 @@ class DictionaryAttr
Attribute get(StringRef name) const;
Attribute get(Identifier name) const;
+ /// Return the specified named attribute if present, None otherwise.
+ Optional<NamedAttribute> getNamed(StringRef name) const;
+ Optional<NamedAttribute> getNamed(Identifier name) const;
+
/// Support range iteration.
using iterator = llvm::ArrayRef<NamedAttribute>::iterator;
iterator begin() const;
@@ -1513,6 +1517,10 @@ class MutableDictionaryAttr {
Attribute get(StringRef name) const;
Attribute get(Identifier name) const;
+ /// Return the specified named attribute if present, None otherwise.
+ Optional<NamedAttribute> getNamed(StringRef name) const;
+ Optional<NamedAttribute> getNamed(Identifier name) const;
+
/// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
void set(Identifier name, Attribute value);
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index bd3076780a15..5c9408199abf 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -205,6 +205,14 @@ class Operation final
/// 'operands'.
void setOperands(ValueRange operands);
+ /// Replace the operands in the half-open range ['start', 'start' + 'length')
+ /// with the ones provided in 'operands'. 'operands' may be smaller or larger
+ /// than the range being replaced; trailing operands are shifted to fit.
+ /// 'start' + 'length' must not exceed getNumOperands().
+ void setOperands(unsigned start, unsigned length, ValueRange operands);
+
+ /// Insert the given operands into the operand list before position 'index'.
+ /// 'index' may be at most getNumOperands(); passing getNumOperands() appends
+ /// the operands to the end of the list.
+ void insertOperands(unsigned index, ValueRange operands);
+
unsigned getNumOperands() {
return LLVM_LIKELY(hasOperandStorage) ? getOperandStorage().size() : 0;
}
@@ -214,6 +222,15 @@ class Operation final
return getOpOperand(idx).set(value);
}
+  /// Erase the single operand at position `idx`.
+  void eraseOperand(unsigned idx) { eraseOperands(idx, /*length=*/1); }
+
+  /// Erase the `length` consecutive operands in the range [idx, idx + length).
+  void eraseOperands(unsigned idx, unsigned length = 1) {
+    auto &storage = getOperandStorage();
+    storage.eraseOperands(idx, length);
+  }
+
// Support operand iteration.
using operand_range = OperandRange;
using operand_iterator = operand_range::iterator;
@@ -221,12 +238,9 @@ class Operation final
operand_iterator operand_begin() { return getOperands().begin(); }
operand_iterator operand_end() { return getOperands().end(); }
- /// Returns an iterator on the underlying Value's (Value ).
+ /// Returns an iterator on the underlying Value's.
operand_range getOperands() { return operand_range(this); }
- /// Erase the operand at position `idx`.
- void eraseOperand(unsigned idx) { getOperandStorage().eraseOperand(idx); }
-
MutableArrayRef<OpOperand> getOpOperands() {
return LLVM_LIKELY(hasOperandStorage) ? getOperandStorage().getOperands()
: MutableArrayRef<OpOperand>();
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 852404f4e071..1beb8d14151c 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -369,8 +369,14 @@ class OperandStorage final
/// 'values'.
void setOperands(Operation *owner, ValueRange values);
- /// Erase an operand held by the storage.
- void eraseOperand(unsigned index);
+ /// Replace the operands beginning at 'start' and ending at 'start' + 'length'
+ /// with the ones provided in 'operands'. 'operands' may be smaller or larger
+ /// than the range pointed to by 'start'+'length'.
+ void setOperands(Operation *owner, unsigned start, unsigned length,
+ ValueRange operands);
+
+ /// Erase the operands held by the storage within the given range.
+ void eraseOperands(unsigned start, unsigned length);
/// Get the operation operands held by the storage.
MutableArrayRef<OpOperand> getOperands() {
@@ -653,6 +659,62 @@ class OperandRange final : public llvm::detail::indexed_accessor_range_base<
friend RangeBaseT;
};
+//===----------------------------------------------------------------------===//
+// MutableOperandRange
+
+/// This class provides a mutable adaptor for a range of operands. It allows for
+/// setting, inserting, and erasing operands from the given range.
+///
+/// The range is tracked as a (start, length) window into the owner's operand
+/// list. `length` is only updated by this object's own mutators, so changes
+/// made to the owner's operands by other means are not reflected here.
+class MutableOperandRange {
+public:
+  /// A pair of the index of this range's entry within an operand segment
+  /// attribute, and the named attribute itself. The attribute should
+  /// correspond to an i32 DenseIntElementsAttr.
+  using OperandSegment = std::pair<unsigned, NamedAttribute>;
+
+  /// Construct a new mutable range from the given operation, operand start
+  /// index, and range length. `operandSegments` is an optional set of operand
+  /// segments to be updated when mutating the operand list.
+  MutableOperandRange(Operation *owner, unsigned start, unsigned length,
+                      ArrayRef<OperandSegment> operandSegments = llvm::None);
+  /// Construct a new mutable range spanning all of the operands of `owner`.
+  MutableOperandRange(Operation *owner);
+
+  /// Append the given values to the range.
+  void append(ValueRange values);
+
+  /// Assign this range to the given values.
+  void assign(ValueRange values);
+
+  /// Assign the range to the given value.
+  void assign(Value value);
+
+  /// Erase `subLen` operands starting at `subStart`, where `subStart` is
+  /// relative to the start of this range.
+  void erase(unsigned subStart, unsigned subLen = 1);
+
+  /// Clear this range and erase all of the operands.
+  void clear();
+
+  /// Returns the current size of the range.
+  unsigned size() const { return length; }
+
+  /// Allow implicit conversion to an OperandRange.
+  operator OperandRange() const;
+
+private:
+  /// Update the length of this range to the one provided, adjusting any
+  /// tracked operand segment attributes to match.
+  void updateLength(unsigned newLength);
+
+  /// The owning operation of this range.
+  Operation *owner;
+
+  /// The start index of the operand range within the owner operand list, and
+  /// the length starting from `start`.
+  unsigned start, length;
+
+  /// Optional set of operand segments that should be updated when mutating the
+  /// length of this range.
+  SmallVector<OperandSegment, 1> operandSegments;
+};
+
//===----------------------------------------------------------------------===//
// ResultRange
diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h
index 06df9edbfa8b..1a6319d2b2e8 100644
--- a/mlir/include/mlir/IR/UseDefLists.h
+++ b/mlir/include/mlir/IR/UseDefLists.h
@@ -164,7 +164,8 @@ template <typename DerivedT, typename IRValueTy> class IROperand {
other.back = nullptr;
nextUse = nullptr;
back = nullptr;
- insertIntoCurrent();
+ if (value)
+ insertIntoCurrent();
return *this;
}
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 40f01e0d21ea..f51b273a9c89 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -196,15 +196,26 @@ ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
/// Return the specified attribute if present, null otherwise.
Attribute DictionaryAttr::get(StringRef name) const {
+ Optional<NamedAttribute> attr = getNamed(name);
+ return attr ? attr->second : nullptr;
+}
+Attribute DictionaryAttr::get(Identifier name) const {
+ Optional<NamedAttribute> attr = getNamed(name);
+ return attr ? attr->second : nullptr;
+}
+
+/// Return the specified named attribute if present, None otherwise.
+Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
ArrayRef<NamedAttribute> values = getValue();
auto it = llvm::lower_bound(values, name, compareNamedAttributeWithName);
- return it != values.end() && it->first == name ? it->second : Attribute();
+ return it != values.end() && it->first == name ? *it
+ : Optional<NamedAttribute>();
}
-Attribute DictionaryAttr::get(Identifier name) const {
+Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
for (auto elt : getValue())
if (elt.first == name)
- return elt.second;
- return nullptr;
+ return elt;
+ return llvm::None;
}
DictionaryAttr::iterator DictionaryAttr::begin() const {
@@ -1191,6 +1202,15 @@ Attribute MutableDictionaryAttr::get(Identifier name) const {
return attrs ? attrs.get(name) : nullptr;
}
+/// Look up the named attribute `name`; yields None when the dictionary is
+/// empty/null or does not contain an entry with that name.
+Optional<NamedAttribute> MutableDictionaryAttr::getNamed(StringRef name) const {
+  if (!attrs)
+    return llvm::None;
+  return attrs.getNamed(name);
+}
+Optional<NamedAttribute>
+MutableDictionaryAttr::getNamed(Identifier name) const {
+  if (!attrs)
+    return llvm::None;
+  return attrs.getNamed(name);
+}
+
/// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
void MutableDictionaryAttr::set(Identifier name, Attribute value) {
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index c8f2d67e3e65..5b439d67ad67 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -244,6 +244,25 @@ void Operation::setOperands(ValueRange operands) {
assert(operands.empty() && "setting operands without an operand storage");
}
+/// Replace the operands in [start, start + length) with 'operands', which may
+/// hold more or fewer values than the range being replaced. Operations without
+/// operand storage only accept an empty replacement.
+void Operation::setOperands(unsigned start, unsigned length,
+                            ValueRange operands) {
+  assert((start + length) <= getNumOperands() &&
+         "invalid operand range specified");
+  if (LLVM_UNLIKELY(!hasOperandStorage)) {
+    assert(operands.empty() && "setting operands without an operand storage");
+    return;
+  }
+  getOperandStorage().setOperands(this, start, length, operands);
+}
+
+/// Insert 'operands' before the operand currently at position 'index'; this is
+/// a zero-length replacement at 'index'.
+void Operation::insertOperands(unsigned index, ValueRange operands) {
+  if (LLVM_UNLIKELY(!hasOperandStorage)) {
+    assert(operands.empty() && "inserting operands without an operand storage");
+    return;
+  }
+  setOperands(index, /*length=*/0, operands);
+}
+
//===----------------------------------------------------------------------===//
// Diagnostics
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 5f698b1bdc53..087828e6e519 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -13,7 +13,9 @@
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Block.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/StandardTypes.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
@@ -89,6 +91,55 @@ void detail::OperandStorage::setOperands(Operation *owner, ValueRange values) {
storageOperands[i].set(values[i]);
}
+/// Replace the operands in the half-open range [start, start + length) with
+/// the ones provided in 'operands'. 'operands' may hold more or fewer values
+/// than 'length'; the storage is shrunk or grown accordingly and any operands
+/// following the replaced range are shifted to stay contiguous.
+void detail::OperandStorage::setOperands(Operation *owner, unsigned start,
+                                         unsigned length, ValueRange operands) {
+  // If the new size is the same, we can update inplace.
+  unsigned newSize = operands.size();
+  if (newSize == length) {
+    MutableArrayRef<OpOperand> storageOperands = getOperands();
+    for (unsigned i = 0; i != length; ++i)
+      storageOperands[start + i].set(operands[i]);
+    return;
+  }
+  // If the new size is smaller, remove the extra operands and set the rest
+  // inplace.
+  if (newSize < length) {
+    eraseOperands(start + newSize, length - newSize);
+    setOperands(owner, start, newSize, operands);
+    return;
+  }
+  // Otherwise, the new size is greater so we need to grow the storage.
+  auto storageOperands = resize(owner, size() + (newSize - length));
+
+  // Shift the operands that followed the replaced range right by
+  // (newSize - length) to make space for the new operands. Rotating the
+  // trailing [start + length, end) window through reverse iterators moves the
+  // freshly-added slots at the end into the gap after the replaced range.
+  unsigned rotateSize = storageOperands.size() - (start + length);
+  auto rbegin = storageOperands.rbegin();
+  std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize);
+
+  // Update the operands inplace.
+  for (unsigned i = 0; i != newSize; ++i)
+    storageOperands[start + i].set(operands[i]);
+}
+
+/// Erase the operands in the half-open range [start, start + length) held by
+/// the storage, shifting any trailing operands down to fill the gap.
+void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
+  TrailingOperandStorage &storage = getStorage();
+  MutableArrayRef<OpOperand> operands = storage.getOperands();
+  assert((start + length) <= operands.size() && "invalid operand range");
+  storage.numOperands -= length;
+
+  // Shift all operands down if the operands to remove are not at the end, so
+  // that the erased operands end up in the trailing slots.
+  if (start != storage.numOperands) {
+    auto indexIt = std::next(operands.begin(), start);
+    std::rotate(indexIt, std::next(indexIt, length), operands.end());
+  }
+  // Destroy the now-unused trailing operands.
+  for (unsigned i = 0; i != length; ++i)
+    operands[storage.numOperands + i].~OpOperand();
+}
+
/// Resize the storage to the given size. Returns the array containing the new
/// operands.
MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
@@ -149,20 +200,6 @@ MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
return newOperands;
}
-/// Erase an operand held by the storage.
-void detail::OperandStorage::eraseOperand(unsigned index) {
- assert(index < size());
- TrailingOperandStorage &storage = getStorage();
- MutableArrayRef<OpOperand> operands = storage.getOperands();
- --storage.numOperands;
-
- // Shift all operands down by 1 if the operand to remove is not at the end.
- auto indexIt = std::next(operands.begin(), index);
- if (index != storage.numOperands)
- std::rotate(indexIt, std::next(indexIt), operands.end());
- operands[storage.numOperands].~OpOperand();
-}
-
//===----------------------------------------------------------------------===//
// ResultStorage
//===----------------------------------------------------------------------===//
@@ -235,6 +272,83 @@ unsigned OperandRange::getBeginOperandIndex() const {
return base->getOperandNumber();
}
+//===----------------------------------------------------------------------===//
+// MutableOperandRange
+
+/// Construct a new mutable range from the given operation, operand start
+/// index, and range length. `operandSegments` lists segment size attributes
+/// (and this range's index within each) to be updated on mutation.
+MutableOperandRange::MutableOperandRange(
+    Operation *owner, unsigned start, unsigned length,
+    ArrayRef<OperandSegment> operandSegments)
+    : owner(owner), start(start), length(length),
+      operandSegments(operandSegments.begin(), operandSegments.end()) {
+  assert((start + length) <= owner->getNumOperands() && "invalid range");
+}
+/// Construct a mutable range covering all of the operands of `owner`.
+MutableOperandRange::MutableOperandRange(Operation *owner)
+    : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
+
+/// Append `values` to the end of this range, growing it accordingly. Appending
+/// an empty range is a no-op.
+void MutableOperandRange::append(ValueRange values) {
+  if (!values.empty()) {
+    owner->insertOperands(start + length, values);
+    updateLength(length + values.size());
+  }
+}
+
+/// Replace the contents of this range with `values`.
+void MutableOperandRange::assign(ValueRange values) {
+  owner->setOperands(start, length, values);
+  unsigned newLength = values.size();
+  if (newLength == length)
+    return;
+  updateLength(newLength);
+}
+
+/// Replace the contents of this range with the single value `value`.
+void MutableOperandRange::assign(Value value) {
+  // A single-element range can be updated in place without resizing.
+  if (length == 1) {
+    owner->setOperand(start, value);
+    return;
+  }
+  owner->setOperands(start, length, value);
+  updateLength(/*newLength=*/1);
+}
+
+/// Erase the `subLen` operands starting at `subStart`, where `subStart` is
+/// relative to the start of this range.
+void MutableOperandRange::erase(unsigned subStart, unsigned subLen) {
+  assert((subStart + subLen) <= length && "invalid sub-range");
+  // Nothing to erase: avoid needlessly rewriting any segment attributes. This
+  // also covers the empty-range case, since the assert forces subLen == 0.
+  if (subLen == 0)
+    return;
+  owner->eraseOperands(start + subStart, subLen);
+  updateLength(length - subLen);
+}
+
+/// Erase every operand in this range, leaving it empty.
+void MutableOperandRange::clear() {
+  if (length == 0)
+    return;
+  owner->eraseOperands(start, length);
+  updateLength(/*newLength=*/0);
+}
+
+/// Implicitly view this range as an OperandRange over the owner's current
+/// operands [start, start + length).
+MutableOperandRange::operator OperandRange() const {
+  OperandRange allOperands = owner->getOperands();
+  return allOperands.slice(start, length);
+}
+
+/// Update the length of this range to the one provided, and adjust this
+/// range's entry in every tracked operand segment attribute by the same delta.
+void MutableOperandRange::updateLength(unsigned newLength) {
+  int32_t diff = int32_t(newLength) - int32_t(length);
+  length = newLength;
+
+  // Update any of the provided segment attributes.
+  for (OperandSegment &segment : operandSegments) {
+    auto attr = segment.second.second.cast<DenseIntElementsAttr>();
+    SmallVector<int32_t, 8> segments(attr.getValues<int32_t>());
+    segments[segment.first] += diff;
+    segment.second.second = DenseIntElementsAttr::get(attr.getType(), segments);
+    owner->setAttr(segment.second.first, segment.second.second);
+  }
+}
+
//===----------------------------------------------------------------------===//
// ResultRange
diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td
index 69f1cfe5d8aa..4fa2f3af29e1 100644
--- a/mlir/test/mlir-tblgen/op-decl.td
+++ b/mlir/test/mlir-tblgen/op-decl.td
@@ -67,6 +67,8 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
// CHECK: Operation::operand_range getODSOperands(unsigned index);
// CHECK: Value a();
// CHECK: Operation::operand_range b();
+// CHECK: ::mlir::MutableOperandRange aMutable();
+// CHECK: ::mlir::MutableOperandRange bMutable();
// CHECK: Operation::result_range getODSResults(unsigned index);
// CHECK: Value r();
// CHECK: Region &someRegion();
@@ -119,6 +121,7 @@ def NS_EOp : NS_Op<"op_with_optionals", []> {
// CHECK-LABEL: NS::EOp declarations
// CHECK: Value a();
+// CHECK: ::mlir::MutableOperandRange aMutable();
// CHECK: Value b();
// CHECK: static void build(OpBuilder &odsBuilder, OperationState &odsState, /*optional*/Type b, /*optional*/Value a)
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index bbfc06f883f8..35fb291498bb 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -45,25 +45,23 @@ static const char *const builderOpState = "odsState";
// {1}: The total number of non-variadic operands/results.
// {2}: The total number of variadic operands/results.
// {3}: The total number of actual values.
-// {4}: The begin iterator of the actual values.
-// {5}: "operand" or "result".
+// {4}: "operand" or "result".
const char *sameVariadicSizeValueRangeCalcCode = R"(
bool isVariadic[] = {{{0}};
int prevVariadicCount = 0;
for (unsigned i = 0; i < index; ++i)
if (isVariadic[i]) ++prevVariadicCount;
- // Calculate how many dynamic values a static variadic {5} corresponds to.
- // This assumes all static variadic {5}s have the same dynamic value count.
+ // Calculate how many dynamic values a static variadic {4} corresponds to.
+ // This assumes all static variadic {4}s have the same dynamic value count.
int variadicSize = ({3} - {1}) / {2};
// `index` passed in as the parameter is the static index which counts each
- // {5} (variadic or not) as size 1. So here for each previous static variadic
- // {5}, we need to offset by (variadicSize - 1) to get where the dynamic
- // value pack for this static {5} starts.
- int offset = index + (variadicSize - 1) * prevVariadicCount;
+ // {4} (variadic or not) as size 1. So here for each previous static variadic
+ // {4}, we need to offset by (variadicSize - 1) to get where the dynamic
+ // value pack for this static {4} starts.
+ int start = index + (variadicSize - 1) * prevVariadicCount;
int size = isVariadic[index] ? variadicSize : 1;
-
- return {{std::next({4}, offset), std::next({4}, offset + size)};
+ return {{start, size};
)";
// The logic to calculate the actual value range for a declared operand/result
@@ -72,14 +70,23 @@ const char *sameVariadicSizeValueRangeCalcCode = R"(
// (variadic or not).
//
// {0}: The name of the attribute specifying the segment sizes.
-// {1}: The begin iterator of the actual values.
const char *attrSizedSegmentValueRangeCalcCode = R"(
auto sizeAttr = getAttrOfType<DenseIntElementsAttr>("{0}");
unsigned start = 0;
for (unsigned i = 0; i < index; ++i)
start += (*(sizeAttr.begin() + i)).getZExtValue();
- unsigned end = start + (*(sizeAttr.begin() + index)).getZExtValue();
- return {{std::next({1}, start), std::next({1}, end)};
+ unsigned size = (*(sizeAttr.begin() + index)).getZExtValue();
+ return {{start, size};
+)";
+
+// The logic to build a range of either operand or result values from a
+// (start, length) pair computed by a generated "...IndexAndLength" method.
+//
+// {0}: The begin iterator of the actual values.
+// {1}: The call that returns the start index and length of the value range as
+//      a std::pair<unsigned, unsigned>.
+// Note: `{{` is formatv's escape for a literal `{` in the generated code.
+const char *valueRangeReturnCode = R"(
+ auto valueRange = {1};
+ return {{std::next({0}, valueRange.first),
+ std::next({0}, valueRange.first + valueRange.second)};
)";
static const char *const opCommentHeader = R"(
@@ -177,6 +184,9 @@ class OpEmitter {
// Generates getters for named operands.
void genNamedOperandGetters();
+ // Generates setters for named operands.
+ void genNamedOperandSetters();
+
// Generates getters for named results.
void genNamedResultGetters();
@@ -310,6 +320,7 @@ OpEmitter::OpEmitter(const Operator &op)
genOpAsmInterface();
genOpNameGetter();
genNamedOperandGetters();
+ genNamedOperandSetters();
genNamedResultGetters();
genNamedRegionGetters();
genNamedSuccessorGetters();
@@ -478,6 +489,37 @@ void OpEmitter::genAttrSetters() {
}
}
+// Generates a method named `methodName` on `opClass` that, given the static
+// index of an ODS operand or result, returns the start index and length of the
+// corresponding range of actual values as a `std::pair<unsigned, unsigned>`.
+//
+// `numVariadic`/`numNonVariadic` count the variadic and non-variadic ODS
+// values, `rangeSizeCall` is an expression yielding the number of actual
+// values, and `segmentSizeAttr` names the segment size attribute used when
+// `hasAttrSegmentSize` is set. `odsValues` is iterated to build the per-value
+// variadic flags.
+template <typename RangeT>
+static void
+generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
+                              int numVariadic, int numNonVariadic,
+                              StringRef rangeSizeCall, bool hasAttrSegmentSize,
+                              StringRef segmentSizeAttr, RangeT &&odsValues) {
+  auto &method = opClass.newMethod("std::pair<unsigned, unsigned>", methodName,
+                                   "unsigned index");
+
+  if (numVariadic == 0) {
+    method.body() << " return {index, 1};\n";
+  } else if (hasAttrSegmentSize) {
+    method.body() << formatv(attrSizedSegmentValueRangeCalcCode,
+                             segmentSizeAttr);
+  } else {
+    // Because the op can have arbitrarily interleaved variadic and non-variadic
+    // values, we need to embed a list in the "sink" getter method for
+    // calculation at run-time.
+    llvm::SmallVector<StringRef, 4> isVariadic;
+    isVariadic.reserve(llvm::size(odsValues));
+    for (auto &it : odsValues)
+      isVariadic.push_back(it.isVariableLength() ? "true" : "false");
+    std::string isVariadicList = llvm::join(isVariadic, ", ");
+
+    // The generated comments should name the kind of value they describe.
+    // The result variant of this method is named "getODSResult...", so infer
+    // the kind from the method name instead of always saying "operand".
+    StringRef valueKind =
+        methodName.find("Result") != StringRef::npos ? "result" : "operand";
+    method.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
+                             numNonVariadic, numVariadic, rangeSizeCall,
+                             valueKind);
+  }
+}
+
// Generates the named operand getter methods for the given Operator `op` and
// puts them in `opClass`. Uses `rangeType` as the return type of getters that
// return a range of operands (individual operands are `Value ` and each
@@ -519,32 +561,16 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
"'SameVariadicOperandSize' traits");
}
- // First emit a "sink" getter method upon which we layer all nicer named
+ // First emit a few "sink" getter methods upon which we layer all nicer named
// getter methods.
- auto &m = opClass.newMethod(rangeType, "getODSOperands", "unsigned index");
+ generateValueRangeStartAndEnd(
+ opClass, "getODSOperandIndexAndLength", numVariadicOperands,
+ numNormalOperands, rangeSizeCall, attrSizedOperands,
+ "operand_segment_sizes", const_cast<Operator &>(op).getOperands());
- if (numVariadicOperands == 0) {
- // We still need to match the return type, which is a range.
- m.body() << " return {std::next(" << rangeBeginCall
- << ", index), std::next(" << rangeBeginCall << ", index + 1)};";
- } else if (attrSizedOperands) {
- m.body() << formatv(attrSizedSegmentValueRangeCalcCode,
- "operand_segment_sizes", rangeBeginCall);
- } else {
- // Because the op can have arbitrarily interleaved variadic and non-variadic
- // operands, we need to embed a list in the "sink" getter method for
- // calculation at run-time.
- llvm::SmallVector<StringRef, 4> isVariadic;
- isVariadic.reserve(numOperands);
- for (int i = 0; i < numOperands; ++i)
- isVariadic.push_back(op.getOperand(i).isVariableLength() ? "true"
- : "false");
- std::string isVariadicList = llvm::join(isVariadic, ", ");
-
- m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
- numNormalOperands, numVariadicOperands, rangeSizeCall,
- rangeBeginCall, "operand");
- }
+ auto &m = opClass.newMethod(rangeType, "getODSOperands", "unsigned index");
+ m.body() << formatv(valueRangeReturnCode, rangeBeginCall,
+ "getODSOperandIndexAndLength(index)");
// Then we emit nicer named getter methods by redirecting to the "sink" getter
// method.
@@ -579,6 +605,26 @@ void OpEmitter::genNamedOperandGetters() {
/*getOperandCallPattern=*/"getOperation()->getOperand({0})");
}
+// Emits a `<name>Mutable()` method for every named operand, returning a
+// MutableOperandRange over that operand's values. When the op has the
+// AttrSizedOperandSegments trait, the range also tracks the
+// "operand_segment_sizes" attribute so it is kept in sync on mutation.
+void OpEmitter::genNamedOperandSetters() {
+  const bool hasAttrSizedOperands =
+      op.getTrait("OpTrait::AttrSizedOperandSegments") != nullptr;
+  for (int idx = 0, numOperands = op.getNumOperands(); idx != numOperands;
+       ++idx) {
+    const auto &namedOperand = op.getOperand(idx);
+    if (namedOperand.name.empty())
+      continue;
+    std::string methodName = (namedOperand.name + "Mutable").str();
+    auto &method = opClass.newMethod("::mlir::MutableOperandRange", methodName);
+    auto &os = method.body();
+    os << " auto range = getODSOperandIndexAndLength(" << idx << ");\n";
+    os << " return ::mlir::MutableOperandRange(getOperation(), "
+          "range.first, range.second";
+    if (hasAttrSizedOperands) {
+      os << ", ::mlir::MutableOperandRange::OperandSegment(" << idx
+         << "u, *getOperation()->getMutableAttrDict().getNamed("
+            "\"operand_segment_sizes\"))";
+    }
+    os << ");\n";
+  }
+}
+
void OpEmitter::genNamedResultGetters() {
const int numResults = op.getNumResults();
const int numVariadicResults = op.getNumVariableLengthResults();
@@ -607,29 +653,14 @@ void OpEmitter::genNamedResultGetters() {
"'SameVariadicResultSize' traits");
}
+ generateValueRangeStartAndEnd(
+ opClass, "getODSResultIndexAndLength", numVariadicResults,
+ numNormalResults, "getOperation()->getNumResults()", attrSizedResults,
+ "result_segment_sizes", op.getResults());
auto &m = opClass.newMethod("Operation::result_range", "getODSResults",
"unsigned index");
-
- if (numVariadicResults == 0) {
- m.body() << " return {std::next(getOperation()->result_begin(), index), "
- "std::next(getOperation()->result_begin(), index + 1)};";
- } else if (attrSizedResults) {
- m.body() << formatv(attrSizedSegmentValueRangeCalcCode,
- "result_segment_sizes",
- "getOperation()->result_begin()");
- } else {
- llvm::SmallVector<StringRef, 4> isVariadic;
- isVariadic.reserve(numResults);
- for (int i = 0; i < numResults; ++i)
- isVariadic.push_back(op.getResult(i).isVariableLength() ? "true"
- : "false");
- std::string isVariadicList = llvm::join(isVariadic, ", ");
-
- m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
- numNormalResults, numVariadicResults,
- "getOperation()->getNumResults()",
- "getOperation()->result_begin()", "result");
- }
+ m.body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()",
+ "getODSResultIndexAndLength(index)");
for (int i = 0; i != numResults; ++i) {
const auto &result = op.getResult(i);
diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp
index 3cfb62553bfb..95ddcccc565e 100644
--- a/mlir/unittests/IR/OperationSupportTest.cpp
+++ b/mlir/unittests/IR/OperationSupportTest.cpp
@@ -33,7 +33,7 @@ TEST(OperandStorageTest, NonResizable) {
Value operand = useOp->getResult(0);
// Create a non-resizable operation with one operand.
- Operation *user = createOp(&context, operand, builder.getIntegerType(16));
+ Operation *user = createOp(&context, operand);
// The same number of operands is okay.
user->setOperands(operand);
@@ -57,7 +57,7 @@ TEST(OperandStorageTest, Resizable) {
Value operand = useOp->getResult(0);
// Create a resizable operation with one operand.
- Operation *user = createOp(&context, operand, builder.getIntegerType(16));
+ Operation *user = createOp(&context, operand);
// The same number of operands is okay.
user->setOperands(operand);
@@ -76,4 +76,77 @@ TEST(OperandStorageTest, Resizable) {
useOp->destroy();
}
+TEST(OperandStorageTest, RangeReplace) {
+  MLIRContext context;
+  Builder builder(&context);
+
+  Operation *useOp =
+      createOp(&context, /*operands=*/llvm::None, builder.getIntegerType(16));
+  Value operand = useOp->getResult(0);
+
+  // Create a resizable operation with one operand.
+  Operation *user = createOp(&context, operand);
+
+  // Check setting with the same number of operands.
+  user->setOperands(/*start=*/0, /*length=*/1, operand);
+  EXPECT_EQ(user->getNumOperands(), 1u);
+
+  // Check setting with more operands.
+  user->setOperands(/*start=*/0, /*length=*/1, {operand, operand, operand});
+  EXPECT_EQ(user->getNumOperands(), 3u);
+
+  // Check setting with fewer operands.
+  user->setOperands(/*start=*/1, /*length=*/2, {operand});
+  EXPECT_EQ(user->getNumOperands(), 2u);
+
+  // Check inserting without replacing operands.
+  user->setOperands(/*start=*/2, /*length=*/0, {operand});
+  EXPECT_EQ(user->getNumOperands(), 3u);
+
+  // Check inserting in the middle through `insertOperands`.
+  user->insertOperands(/*index=*/1, {operand});
+  EXPECT_EQ(user->getNumOperands(), 4u);
+
+  // Check erasing a sub-range through `eraseOperands`.
+  user->eraseOperands(/*idx=*/1, /*length=*/2);
+  EXPECT_EQ(user->getNumOperands(), 2u);
+
+  // Check erasing the remaining operands.
+  user->setOperands(/*start=*/0, /*length=*/2, {});
+  EXPECT_EQ(user->getNumOperands(), 0u);
+
+  // Destroy the operations.
+  user->destroy();
+  useOp->destroy();
+}
+
+TEST(OperandStorageTest, MutableRange) {
+  MLIRContext context;
+  Builder builder(&context);
+
+  Operation *useOp =
+      createOp(&context, /*operands=*/llvm::None, builder.getIntegerType(16));
+  Value operand = useOp->getResult(0);
+
+  // Create a resizable operation with one operand.
+  Operation *user = createOp(&context, operand);
+
+  // Check setting with the same number of operands.
+  MutableOperandRange mutableOperands(user);
+  mutableOperands.assign(operand);
+  EXPECT_EQ(mutableOperands.size(), 1u);
+  EXPECT_EQ(user->getNumOperands(), 1u);
+
+  // Check setting with more operands.
+  mutableOperands.assign({operand, operand, operand});
+  EXPECT_EQ(mutableOperands.size(), 3u);
+  EXPECT_EQ(user->getNumOperands(), 3u);
+
+  // Check appending new operands.
+  mutableOperands.append({operand, operand});
+  EXPECT_EQ(mutableOperands.size(), 5u);
+  EXPECT_EQ(user->getNumOperands(), 5u);
+
+  // Check erasing a sub-range of operands.
+  mutableOperands.erase(/*subStart=*/1, /*subLen=*/2);
+  EXPECT_EQ(mutableOperands.size(), 3u);
+  EXPECT_EQ(user->getNumOperands(), 3u);
+
+  // Check clearing all remaining operands.
+  mutableOperands.clear();
+  EXPECT_EQ(mutableOperands.size(), 0u);
+  EXPECT_EQ(user->getNumOperands(), 0u);
+
+  // Destroy the operations.
+  user->destroy();
+  useOp->destroy();
+}
+
} // end namespace
More information about the Mlir-commits
mailing list