[Mlir-commits] [mlir] 108abd2 - [mlir] Add a new MutableOperandRange class for adding/removing operands

River Riddle llvmlistbot at llvm.org
Wed Apr 29 16:49:42 PDT 2020


Author: River Riddle
Date: 2020-04-29T16:48:14-07:00
New Revision: 108abd2f2eaf9a197f92fcc672777f8561d330dd

URL: https://github.com/llvm/llvm-project/commit/108abd2f2eaf9a197f92fcc672777f8561d330dd
DIFF: https://github.com/llvm/llvm-project/commit/108abd2f2eaf9a197f92fcc672777f8561d330dd.diff

LOG: [mlir] Add a new MutableOperandRange class for adding/removing operands

This class allows for mutating an operand range in-place, and provides a vector-like API for adding/erasing/setting. ODS now uses this class to generate mutable wrappers for named operands, with the name `MutableOperandRange <operand-name>Mutable()`

Differential Revision: https://reviews.llvm.org/D78892

Added: 
    

Modified: 
    mlir/include/mlir/IR/Attributes.h
    mlir/include/mlir/IR/Operation.h
    mlir/include/mlir/IR/OperationSupport.h
    mlir/include/mlir/IR/UseDefLists.h
    mlir/lib/IR/Attributes.cpp
    mlir/lib/IR/Operation.cpp
    mlir/lib/IR/OperationSupport.cpp
    mlir/test/mlir-tblgen/op-decl.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
    mlir/unittests/IR/OperationSupportTest.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index e42f50b7ac32..43628b8ce506 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -296,6 +296,10 @@ class DictionaryAttr
   Attribute get(StringRef name) const;
   Attribute get(Identifier name) const;
 
+  /// Return the specified named attribute if present, None otherwise.
+  Optional<NamedAttribute> getNamed(StringRef name) const;
+  Optional<NamedAttribute> getNamed(Identifier name) const;
+
   /// Support range iteration.
   using iterator = llvm::ArrayRef<NamedAttribute>::iterator;
   iterator begin() const;
@@ -1513,6 +1517,10 @@ class MutableDictionaryAttr {
   Attribute get(StringRef name) const;
   Attribute get(Identifier name) const;
 
+  /// Return the specified named attribute if present, None otherwise.
+  Optional<NamedAttribute> getNamed(StringRef name) const;
+  Optional<NamedAttribute> getNamed(Identifier name) const;
+
   /// If the an attribute exists with the specified name, change it to the new
   /// value.  Otherwise, add a new attribute with the specified name/value.
   void set(Identifier name, Attribute value);

diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index bd3076780a15..5c9408199abf 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -205,6 +205,14 @@ class Operation final
   /// 'operands'.
   void setOperands(ValueRange operands);
 
+  /// Replace the operands beginning at 'start' and ending at 'start' + 'length'
+  /// with the ones provided in 'operands'. 'operands' may be smaller or larger
+  /// than the range pointed to by 'start'+'length'.
+  void setOperands(unsigned start, unsigned length, ValueRange operands);
+
+  /// Insert the given operands into the operand list at the given 'index'.
+  void insertOperands(unsigned index, ValueRange operands);
+
   unsigned getNumOperands() {
     return LLVM_LIKELY(hasOperandStorage) ? getOperandStorage().size() : 0;
   }
@@ -214,6 +222,15 @@ class Operation final
     return getOpOperand(idx).set(value);
   }
 
+  /// Erase the operand at position `idx`.
+  void eraseOperand(unsigned idx) { eraseOperands(idx); }
+
+  /// Erase the operands starting at position `idx` and ending at position
+  /// 'idx'+'length'.
+  void eraseOperands(unsigned idx, unsigned length = 1) {
+    getOperandStorage().eraseOperands(idx, length);
+  }
+
   // Support operand iteration.
   using operand_range = OperandRange;
   using operand_iterator = operand_range::iterator;
@@ -221,12 +238,9 @@ class Operation final
   operand_iterator operand_begin() { return getOperands().begin(); }
   operand_iterator operand_end() { return getOperands().end(); }
 
-  /// Returns an iterator on the underlying Value's (Value ).
+  /// Returns an iterator on the underlying Value's.
   operand_range getOperands() { return operand_range(this); }
 
-  /// Erase the operand at position `idx`.
-  void eraseOperand(unsigned idx) { getOperandStorage().eraseOperand(idx); }
-
   MutableArrayRef<OpOperand> getOpOperands() {
     return LLVM_LIKELY(hasOperandStorage) ? getOperandStorage().getOperands()
                                           : MutableArrayRef<OpOperand>();

diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 852404f4e071..1beb8d14151c 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -369,8 +369,14 @@ class OperandStorage final
   /// 'values'.
   void setOperands(Operation *owner, ValueRange values);
 
-  /// Erase an operand held by the storage.
-  void eraseOperand(unsigned index);
+  /// Replace the operands beginning at 'start' and ending at 'start' + 'length'
+  /// with the ones provided in 'operands'. 'operands' may be smaller or larger
+  /// than the range pointed to by 'start'+'length'.
+  void setOperands(Operation *owner, unsigned start, unsigned length,
+                   ValueRange operands);
+
+  /// Erase the operands held by the storage within the given range.
+  void eraseOperands(unsigned start, unsigned length);
 
   /// Get the operation operands held by the storage.
   MutableArrayRef<OpOperand> getOperands() {
@@ -653,6 +659,62 @@ class OperandRange final : public llvm::detail::indexed_accessor_range_base<
   friend RangeBaseT;
 };
 
+//===----------------------------------------------------------------------===//
+// MutableOperandRange
+
+/// This class provides a mutable adaptor for a range of operands. It allows for
+/// setting, inserting, and erasing operands from the given range.
+class MutableOperandRange {
+public:
+  /// A pair of an index into an operand segment attribute, and the named
+  /// attribute itself; i.e. `first` is the index and `second` the attribute.
+  /// The attribute should correspond to an i32 DenseElementsAttr.
+  using OperandSegment = std::pair<unsigned, NamedAttribute>;
+
+  /// Construct a new mutable range from the given operation, operand start
+  /// index, and range length. `operandSegments` is an optional set of operand
+  /// segments to be updated when mutating the operand list.
+  MutableOperandRange(Operation *owner, unsigned start, unsigned length,
+                      ArrayRef<OperandSegment> operandSegments = llvm::None);
+  MutableOperandRange(Operation *owner);
+
+  /// Append the given values to the range.
+  void append(ValueRange values);
+
+  /// Assign this range to the given values.
+  void assign(ValueRange values);
+
+  /// Assign the range to the given value.
+  void assign(Value value);
+
+  /// Erase the operands within the given sub-range.
+  void erase(unsigned subStart, unsigned subLen = 1);
+
+  /// Clear this range and erase all of the operands.
+  void clear();
+
+  /// Returns the current size of the range.
+  unsigned size() const { return length; }
+
+  /// Allow implicit conversion to an OperandRange.
+  operator OperandRange() const;
+
+private:
+  /// Update the length of this range to the one provided.
+  void updateLength(unsigned newLength);
+
+  /// The owning operation of this range.
+  Operation *owner;
+
+  /// The start index of the operand range within the owner operand list, and
+  /// the length starting from `start`.
+  unsigned start, length;
+
+  /// Optional set of operand segments that should be updated when mutating the
+  /// length of this range.
+  SmallVector<std::pair<unsigned, NamedAttribute>, 1> operandSegments;
+};
+
 //===----------------------------------------------------------------------===//
 // ResultRange
 

diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h
index 06df9edbfa8b..1a6319d2b2e8 100644
--- a/mlir/include/mlir/IR/UseDefLists.h
+++ b/mlir/include/mlir/IR/UseDefLists.h
@@ -164,7 +164,8 @@ template <typename DerivedT, typename IRValueTy> class IROperand {
     other.back = nullptr;
     nextUse = nullptr;
     back = nullptr;
-    insertIntoCurrent();
+    if (value)
+      insertIntoCurrent();
     return *this;
   }
 

diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 40f01e0d21ea..f51b273a9c89 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -196,15 +196,26 @@ ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
 
 /// Return the specified attribute if present, null otherwise.
 Attribute DictionaryAttr::get(StringRef name) const {
+  Optional<NamedAttribute> attr = getNamed(name);
+  return attr ? attr->second : nullptr;
+}
+Attribute DictionaryAttr::get(Identifier name) const {
+  Optional<NamedAttribute> attr = getNamed(name);
+  return attr ? attr->second : nullptr;
+}
+
+/// Return the specified named attribute if present, None otherwise.
+Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
   ArrayRef<NamedAttribute> values = getValue();
   auto it = llvm::lower_bound(values, name, compareNamedAttributeWithName);
-  return it != values.end() && it->first == name ? it->second : Attribute();
+  return it != values.end() && it->first == name ? *it
+                                                 : Optional<NamedAttribute>();
 }
-Attribute DictionaryAttr::get(Identifier name) const {
+Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
   for (auto elt : getValue())
     if (elt.first == name)
-      return elt.second;
-  return nullptr;
+      return elt;
+  return llvm::None;
 }
 
 DictionaryAttr::iterator DictionaryAttr::begin() const {
@@ -1191,6 +1202,15 @@ Attribute MutableDictionaryAttr::get(Identifier name) const {
   return attrs ? attrs.get(name) : nullptr;
 }
 
+/// Return the specified named attribute if present, None otherwise.
+Optional<NamedAttribute> MutableDictionaryAttr::getNamed(StringRef name) const {
+  return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
+}
+Optional<NamedAttribute>
+MutableDictionaryAttr::getNamed(Identifier name) const {
+  return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
+}
+
 /// If the an attribute exists with the specified name, change it to the new
 /// value.  Otherwise, add a new attribute with the specified name/value.
 void MutableDictionaryAttr::set(Identifier name, Attribute value) {

diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index c8f2d67e3e65..5b439d67ad67 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -244,6 +244,25 @@ void Operation::setOperands(ValueRange operands) {
   assert(operands.empty() && "setting operands without an operand storage");
 }
 
+/// Replace the operands beginning at 'start' and ending at 'start' + 'length'
+/// with the ones provided in 'operands'. 'operands' may be smaller or larger
+/// than the range pointed to by 'start'+'length'.
+void Operation::setOperands(unsigned start, unsigned length,
+                            ValueRange operands) {
+  assert((start + length) <= getNumOperands() &&
+         "invalid operand range specified");
+  if (LLVM_LIKELY(hasOperandStorage))
+    return getOperandStorage().setOperands(this, start, length, operands);
+  assert(operands.empty() && "setting operands without an operand storage");
+}
+
+/// Insert the given operands into the operand list at the given 'index'.
+void Operation::insertOperands(unsigned index, ValueRange operands) {
+  if (LLVM_LIKELY(hasOperandStorage))
+    return setOperands(index, /*length=*/0, operands);
+  assert(operands.empty() && "inserting operands without an operand storage");
+}
+
 //===----------------------------------------------------------------------===//
 // Diagnostics
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 5f698b1bdc53..087828e6e519 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -13,7 +13,9 @@
 
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/Block.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/StandardTypes.h"
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
@@ -89,6 +91,55 @@ void detail::OperandStorage::setOperands(Operation *owner, ValueRange values) {
     storageOperands[i].set(values[i]);
 }
 
+/// Replace the operands beginning at 'start' and ending at 'start' + 'length'
+/// with the ones provided in 'operands'. 'operands' may be smaller or larger
+/// than the range pointed to by 'start'+'length'.
+void detail::OperandStorage::setOperands(Operation *owner, unsigned start,
+                                         unsigned length, ValueRange operands) {
+  // If the new size is the same, we can update inplace.
+  unsigned newSize = operands.size();
+  if (newSize == length) {
+    MutableArrayRef<OpOperand> storageOperands = getOperands();
+    for (unsigned i = 0, e = length; i != e; ++i)
+      storageOperands[start + i].set(operands[i]);
+    return;
+  }
+  // If the new size is smaller, remove the extra operands and set the rest
+  // inplace.
+  if (newSize < length) {
+    eraseOperands(start + operands.size(), length - newSize);
+    setOperands(owner, start, newSize, operands);
+    return;
+  }
+  // Otherwise, the new size is greater so we need to grow the storage.
+  auto storageOperands = resize(owner, size() + (newSize - length));
+
+  // Shift operands to the right to make space for the new operands.
+  unsigned rotateSize = storageOperands.size() - (start + length);
+  auto rbegin = storageOperands.rbegin();
+  std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize);
+
+  // Update the operands inplace.
+  for (unsigned i = 0, e = operands.size(); i != e; ++i)
+    storageOperands[start + i].set(operands[i]);
+}
+
+/// Erase the operands held by the storage within the given range.
+void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
+  TrailingOperandStorage &storage = getStorage();
+  MutableArrayRef<OpOperand> operands = storage.getOperands();
+  assert((start + length) <= operands.size());
+  storage.numOperands -= length;
+
+  // Shift all operands down if the range to remove is not at the end.
+  if (start != storage.numOperands) {
+    auto indexIt = std::next(operands.begin(), start);
+    std::rotate(indexIt, std::next(indexIt, length), operands.end());
+  }
+  for (unsigned i = 0; i != length; ++i)
+    operands[storage.numOperands + i].~OpOperand();
+}
+
 /// Resize the storage to the given size. Returns the array containing the new
 /// operands.
 MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
@@ -149,20 +200,6 @@ MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
   return newOperands;
 }
 
-/// Erase an operand held by the storage.
-void detail::OperandStorage::eraseOperand(unsigned index) {
-  assert(index < size());
-  TrailingOperandStorage &storage = getStorage();
-  MutableArrayRef<OpOperand> operands = storage.getOperands();
-  --storage.numOperands;
-
-  // Shift all operands down by 1 if the operand to remove is not at the end.
-  auto indexIt = std::next(operands.begin(), index);
-  if (index != storage.numOperands)
-    std::rotate(indexIt, std::next(indexIt), operands.end());
-  operands[storage.numOperands].~OpOperand();
-}
-
 //===----------------------------------------------------------------------===//
 // ResultStorage
 //===----------------------------------------------------------------------===//
@@ -235,6 +272,83 @@ unsigned OperandRange::getBeginOperandIndex() const {
   return base->getOperandNumber();
 }
 
+//===----------------------------------------------------------------------===//
+// MutableOperandRange
+
+/// Construct a new mutable range from the given operation, operand start index,
+/// and range length.
+MutableOperandRange::MutableOperandRange(
+    Operation *owner, unsigned start, unsigned length,
+    ArrayRef<OperandSegment> operandSegments)
+    : owner(owner), start(start), length(length),
+      operandSegments(operandSegments.begin(), operandSegments.end()) {
+  assert((start + length) <= owner->getNumOperands() && "invalid range");
+}
+MutableOperandRange::MutableOperandRange(Operation *owner)
+    : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
+
+/// Append the given values to the range.
+void MutableOperandRange::append(ValueRange values) {
+  if (values.empty())
+    return;
+  owner->insertOperands(start + length, values);
+  updateLength(length + values.size());
+}
+
+/// Assign this range to the given values.
+void MutableOperandRange::assign(ValueRange values) {
+  owner->setOperands(start, length, values);
+  if (length != values.size())
+    updateLength(/*newLength=*/values.size());
+}
+
+/// Assign the range to the given value.
+void MutableOperandRange::assign(Value value) {
+  if (length == 1) {
+    owner->setOperand(start, value);
+  } else {
+    owner->setOperands(start, length, value);
+    updateLength(/*newLength=*/1);
+  }
+}
+
+/// Erase the operands within the given sub-range.
+void MutableOperandRange::erase(unsigned subStart, unsigned subLen) {
+  assert((subStart + subLen) <= length && "invalid sub-range");
+  if (length == 0)
+    return;
+  owner->eraseOperands(start + subStart, subLen);
+  updateLength(length - subLen);
+}
+
+/// Clear this range and erase all of the operands.
+void MutableOperandRange::clear() {
+  if (length != 0) {
+    owner->eraseOperands(start, length);
+    updateLength(/*newLength=*/0);
+  }
+}
+
+/// Allow implicit conversion to an OperandRange.
+MutableOperandRange::operator OperandRange() const {
+  return owner->getOperands().slice(start, length);
+}
+
+/// Update the length of this range to the one provided.
+void MutableOperandRange::updateLength(unsigned newLength) {
+  int32_t diff = int32_t(newLength) - int32_t(length);
+  length = newLength;
+
+  // Update any of the provided segment attributes.
+  for (OperandSegment &segment : operandSegments) {
+    auto attr = segment.second.second.cast<DenseIntElementsAttr>();
+    SmallVector<int32_t, 8> segments(attr.getValues<int32_t>());
+    segments[segment.first] += diff;
+    segment.second.second = DenseIntElementsAttr::get(attr.getType(), segments);
+    owner->setAttr(segment.second.first, segment.second.second);
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // ResultRange
 

diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td
index 69f1cfe5d8aa..4fa2f3af29e1 100644
--- a/mlir/test/mlir-tblgen/op-decl.td
+++ b/mlir/test/mlir-tblgen/op-decl.td
@@ -67,6 +67,8 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
 // CHECK:   Operation::operand_range getODSOperands(unsigned index);
 // CHECK:   Value a();
 // CHECK:   Operation::operand_range b();
+// CHECK:   ::mlir::MutableOperandRange aMutable();
+// CHECK:   ::mlir::MutableOperandRange bMutable();
 // CHECK:   Operation::result_range getODSResults(unsigned index);
 // CHECK:   Value r();
 // CHECK:   Region &someRegion();
@@ -119,6 +121,7 @@ def NS_EOp : NS_Op<"op_with_optionals", []> {
 
 // CHECK-LABEL: NS::EOp declarations
 // CHECK:   Value a();
+// CHECK:   ::mlir::MutableOperandRange aMutable();
 // CHECK:   Value b();
 // CHECK:   static void build(OpBuilder &odsBuilder, OperationState &odsState, /*optional*/Type b, /*optional*/Value a)
 

diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index bbfc06f883f8..35fb291498bb 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -45,25 +45,23 @@ static const char *const builderOpState = "odsState";
 // {1}: The total number of non-variadic operands/results.
 // {2}: The total number of variadic operands/results.
 // {3}: The total number of actual values.
-// {4}: The begin iterator of the actual values.
-// {5}: "operand" or "result".
+// {4}: "operand" or "result".
 const char *sameVariadicSizeValueRangeCalcCode = R"(
   bool isVariadic[] = {{{0}};
   int prevVariadicCount = 0;
   for (unsigned i = 0; i < index; ++i)
     if (isVariadic[i]) ++prevVariadicCount;
 
-  // Calculate how many dynamic values a static variadic {5} corresponds to.
-  // This assumes all static variadic {5}s have the same dynamic value count.
+  // Calculate how many dynamic values a static variadic {4} corresponds to.
+  // This assumes all static variadic {4}s have the same dynamic value count.
   int variadicSize = ({3} - {1}) / {2};
   // `index` passed in as the parameter is the static index which counts each
-  // {5} (variadic or not) as size 1. So here for each previous static variadic
-  // {5}, we need to offset by (variadicSize - 1) to get where the dynamic
-  // value pack for this static {5} starts.
-  int offset = index + (variadicSize - 1) * prevVariadicCount;
+  // {4} (variadic or not) as size 1. So here for each previous static variadic
+  // {4}, we need to offset by (variadicSize - 1) to get where the dynamic
+  // value pack for this static {4} starts.
+  int start = index + (variadicSize - 1) * prevVariadicCount;
   int size = isVariadic[index] ? variadicSize : 1;
-
-  return {{std::next({4}, offset), std::next({4}, offset + size)};
+  return {{start, size};
 )";
 
 // The logic to calculate the actual value range for a declared operand/result
@@ -72,14 +70,23 @@ const char *sameVariadicSizeValueRangeCalcCode = R"(
 // (variadic or not).
 //
 // {0}: The name of the attribute specifying the segment sizes.
-// {1}: The begin iterator of the actual values.
 const char *attrSizedSegmentValueRangeCalcCode = R"(
   auto sizeAttr = getAttrOfType<DenseIntElementsAttr>("{0}");
   unsigned start = 0;
   for (unsigned i = 0; i < index; ++i)
     start += (*(sizeAttr.begin() + i)).getZExtValue();
-  unsigned end = start + (*(sizeAttr.begin() + index)).getZExtValue();
-  return {{std::next({1}, start), std::next({1}, end)};
+  unsigned size = (*(sizeAttr.begin() + index)).getZExtValue();
+  return {{start, size};
+)";
+
+// The logic to build a range of either operand or result values.
+//
+// {0}: The begin iterator of the actual values.
+// {1}: The call to generate the start and length of the value range.
+const char *valueRangeReturnCode = R"(
+  auto valueRange = {1};
+  return {{std::next({0}, valueRange.first),
+           std::next({0}, valueRange.first + valueRange.second)};
 )";
 
 static const char *const opCommentHeader = R"(
@@ -177,6 +184,9 @@ class OpEmitter {
   // Generates getters for named operands.
   void genNamedOperandGetters();
 
+  // Generates setters for named operands.
+  void genNamedOperandSetters();
+
   // Generates getters for named results.
   void genNamedResultGetters();
 
@@ -310,6 +320,7 @@ OpEmitter::OpEmitter(const Operator &op)
   genOpAsmInterface();
   genOpNameGetter();
   genNamedOperandGetters();
+  genNamedOperandSetters();
   genNamedResultGetters();
   genNamedRegionGetters();
   genNamedSuccessorGetters();
@@ -478,6 +489,37 @@ void OpEmitter::genAttrSetters() {
   }
 }
 
+// Generates a method that computes the start index and length of an operand or
+// result range.
+template <typename RangeT>
+static void
+generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
+                              int numVariadic, int numNonVariadic,
+                              StringRef rangeSizeCall, bool hasAttrSegmentSize,
+                              StringRef segmentSizeAttr, RangeT &&odsValues) {
+  auto &method = opClass.newMethod("std::pair<unsigned, unsigned>", methodName,
+                                   "unsigned index");
+
+  if (numVariadic == 0) {
+    method.body() << "  return {index, 1};\n";
+  } else if (hasAttrSegmentSize) {
+    method.body() << formatv(attrSizedSegmentValueRangeCalcCode,
+                             segmentSizeAttr);
+  } else {
+    // Because the op can have arbitrarily interleaved variadic and non-variadic
+    // operands, we need to embed a list in the "sink" getter method for
+    // calculation at run-time.
+    llvm::SmallVector<StringRef, 4> isVariadic;
+    isVariadic.reserve(llvm::size(odsValues));
+    for (auto &it : odsValues)
+      isVariadic.push_back(it.isVariableLength() ? "true" : "false");
+    std::string isVariadicList = llvm::join(isVariadic, ", ");
+    method.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
+                             numNonVariadic, numVariadic, rangeSizeCall,
+                             "operand");
+  }
+}
+
 // Generates the named operand getter methods for the given Operator `op` and
 // puts them in `opClass`.  Uses `rangeType` as the return type of getters that
 // return a range of operands (individual operands are `Value ` and each
@@ -519,32 +561,16 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
                     "'SameVariadicOperandSize' traits");
   }
 
-  // First emit a "sink" getter method upon which we layer all nicer named
+  // First emit a few "sink" getter methods upon which we layer all nicer named
   // getter methods.
-  auto &m = opClass.newMethod(rangeType, "getODSOperands", "unsigned index");
+  generateValueRangeStartAndEnd(
+      opClass, "getODSOperandIndexAndLength", numVariadicOperands,
+      numNormalOperands, rangeSizeCall, attrSizedOperands,
+      "operand_segment_sizes", const_cast<Operator &>(op).getOperands());
 
-  if (numVariadicOperands == 0) {
-    // We still need to match the return type, which is a range.
-    m.body() << "  return {std::next(" << rangeBeginCall
-             << ", index), std::next(" << rangeBeginCall << ", index + 1)};";
-  } else if (attrSizedOperands) {
-    m.body() << formatv(attrSizedSegmentValueRangeCalcCode,
-                        "operand_segment_sizes", rangeBeginCall);
-  } else {
-    // Because the op can have arbitrarily interleaved variadic and non-variadic
-    // operands, we need to embed a list in the "sink" getter method for
-    // calculation at run-time.
-    llvm::SmallVector<StringRef, 4> isVariadic;
-    isVariadic.reserve(numOperands);
-    for (int i = 0; i < numOperands; ++i)
-      isVariadic.push_back(op.getOperand(i).isVariableLength() ? "true"
-                                                               : "false");
-    std::string isVariadicList = llvm::join(isVariadic, ", ");
-
-    m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
-                        numNormalOperands, numVariadicOperands, rangeSizeCall,
-                        rangeBeginCall, "operand");
-  }
+  auto &m = opClass.newMethod(rangeType, "getODSOperands", "unsigned index");
+  m.body() << formatv(valueRangeReturnCode, rangeBeginCall,
+                      "getODSOperandIndexAndLength(index)");
 
   // Then we emit nicer named getter methods by redirecting to the "sink" getter
   // method.
@@ -579,6 +605,26 @@ void OpEmitter::genNamedOperandGetters() {
       /*getOperandCallPattern=*/"getOperation()->getOperand({0})");
 }
 
+void OpEmitter::genNamedOperandSetters() {
+  auto *attrSizedOperands = op.getTrait("OpTrait::AttrSizedOperandSegments");
+  for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
+    const auto &operand = op.getOperand(i);
+    if (operand.name.empty())
+      continue;
+    auto &m = opClass.newMethod("::mlir::MutableOperandRange",
+                                (operand.name + "Mutable").str());
+    auto &body = m.body();
+    body << "  auto range = getODSOperandIndexAndLength(" << i << ");\n"
+         << "  return ::mlir::MutableOperandRange(getOperation(), "
+            "range.first, range.second";
+    if (attrSizedOperands)
+      body << ", ::mlir::MutableOperandRange::OperandSegment(" << i
+           << "u, *getOperation()->getMutableAttrDict().getNamed("
+              "\"operand_segment_sizes\"))";
+    body << ");\n";
+  }
+}
+
 void OpEmitter::genNamedResultGetters() {
   const int numResults = op.getNumResults();
   const int numVariadicResults = op.getNumVariableLengthResults();
@@ -607,29 +653,14 @@ void OpEmitter::genNamedResultGetters() {
                     "'SameVariadicResultSize' traits");
   }
 
+  generateValueRangeStartAndEnd(
+      opClass, "getODSResultIndexAndLength", numVariadicResults,
+      numNormalResults, "getOperation()->getNumResults()", attrSizedResults,
+      "result_segment_sizes", op.getResults());
   auto &m = opClass.newMethod("Operation::result_range", "getODSResults",
                               "unsigned index");
-
-  if (numVariadicResults == 0) {
-    m.body() << "  return {std::next(getOperation()->result_begin(), index), "
-                "std::next(getOperation()->result_begin(), index + 1)};";
-  } else if (attrSizedResults) {
-    m.body() << formatv(attrSizedSegmentValueRangeCalcCode,
-                        "result_segment_sizes",
-                        "getOperation()->result_begin()");
-  } else {
-    llvm::SmallVector<StringRef, 4> isVariadic;
-    isVariadic.reserve(numResults);
-    for (int i = 0; i < numResults; ++i)
-      isVariadic.push_back(op.getResult(i).isVariableLength() ? "true"
-                                                              : "false");
-    std::string isVariadicList = llvm::join(isVariadic, ", ");
-
-    m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
-                        numNormalResults, numVariadicResults,
-                        "getOperation()->getNumResults()",
-                        "getOperation()->result_begin()", "result");
-  }
+  m.body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()",
+                      "getODSResultIndexAndLength(index)");
 
   for (int i = 0; i != numResults; ++i) {
     const auto &result = op.getResult(i);

diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp
index 3cfb62553bfb..95ddcccc565e 100644
--- a/mlir/unittests/IR/OperationSupportTest.cpp
+++ b/mlir/unittests/IR/OperationSupportTest.cpp
@@ -33,7 +33,7 @@ TEST(OperandStorageTest, NonResizable) {
   Value operand = useOp->getResult(0);
 
   // Create a non-resizable operation with one operand.
-  Operation *user = createOp(&context, operand, builder.getIntegerType(16));
+  Operation *user = createOp(&context, operand);
 
   // The same number of operands is okay.
   user->setOperands(operand);
@@ -57,7 +57,7 @@ TEST(OperandStorageTest, Resizable) {
   Value operand = useOp->getResult(0);
 
   // Create a resizable operation with one operand.
-  Operation *user = createOp(&context, operand, builder.getIntegerType(16));
+  Operation *user = createOp(&context, operand);
 
   // The same number of operands is okay.
   user->setOperands(operand);
@@ -76,4 +76,77 @@ TEST(OperandStorageTest, Resizable) {
   useOp->destroy();
 }
 
+TEST(OperandStorageTest, RangeReplace) {
+  MLIRContext context;
+  Builder builder(&context);
+
+  Operation *useOp =
+      createOp(&context, /*operands=*/llvm::None, builder.getIntegerType(16));
+  Value operand = useOp->getResult(0);
+
+  // Create a resizable operation with one operand.
+  Operation *user = createOp(&context, operand);
+
+  // Check setting with the same number of operands.
+  user->setOperands(/*start=*/0, /*length=*/1, operand);
+  EXPECT_EQ(user->getNumOperands(), 1u);
+
+  // Check setting with more operands.
+  user->setOperands(/*start=*/0, /*length=*/1, {operand, operand, operand});
+  EXPECT_EQ(user->getNumOperands(), 3u);
+
+  // Check setting with less operands.
+  user->setOperands(/*start=*/1, /*length=*/2, {operand});
+  EXPECT_EQ(user->getNumOperands(), 2u);
+
+  // Check inserting without replacing operands.
+  user->setOperands(/*start=*/2, /*length=*/0, {operand});
+  EXPECT_EQ(user->getNumOperands(), 3u);
+
+  // Check erasing operands.
+  user->setOperands(/*start=*/0, /*length=*/3, {});
+  EXPECT_EQ(user->getNumOperands(), 0u);
+
+  // Destroy the operations.
+  user->destroy();
+  useOp->destroy();
+}
+
+TEST(OperandStorageTest, MutableRange) {
+  MLIRContext context;
+  Builder builder(&context);
+
+  Operation *useOp =
+      createOp(&context, /*operands=*/llvm::None, builder.getIntegerType(16));
+  Value operand = useOp->getResult(0);
+
+  // Create a resizable operation with one operand.
+  Operation *user = createOp(&context, operand);
+
+  // Check setting with the same number of operands.
+  MutableOperandRange mutableOperands(user);
+  mutableOperands.assign(operand);
+  EXPECT_EQ(mutableOperands.size(), 1u);
+  EXPECT_EQ(user->getNumOperands(), 1u);
+
+  // Check setting with more operands.
+  mutableOperands.assign({operand, operand, operand});
+  EXPECT_EQ(mutableOperands.size(), 3u);
+  EXPECT_EQ(user->getNumOperands(), 3u);
+
+  // Check with inserting a new operand.
+  mutableOperands.append({operand, operand});
+  EXPECT_EQ(mutableOperands.size(), 5u);
+  EXPECT_EQ(user->getNumOperands(), 5u);
+
+  // Check erasing operands.
+  mutableOperands.clear();
+  EXPECT_EQ(mutableOperands.size(), 0u);
+  EXPECT_EQ(user->getNumOperands(), 0u);
+
+  // Destroy the operations.
+  user->destroy();
+  useOp->destroy();
+}
+
 } // end namespace


        


More information about the Mlir-commits mailing list