[Mlir-commits] [mlir] [mlir][TblGen] get...Mutable returns OpOperand & for single operands (PR #66519)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 15 08:16:51 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tensor

<details>
<summary>Changes</summary>

The TableGen code generator now emits C++ code in which `get...Mutable` returns a single `OpOperand &` for operands that are neither variadic nor optional. `OpOperand::set`/`assign` can be used to set the operand's value (just like `MutableOperandRange::assign`). Returning an `OpOperand &` is safer than returning a `MutableOperandRange`, because only a single `Value` (and no longer a `ValueRange`) can be assigned.

E.g.:
```
// Before: Assign multiple values to non-variadic operand (forbidden, but
//         compiles).
// After: Compilation error.
extractSliceOp.getSourceMutable().assign({v1, v2});
```

Depends on #66515. Review only the top commit.

---
Full diff: https://github.com/llvm/llvm-project/pull/66519.diff


11 Files Affected:

- (modified) mlir/include/mlir/IR/Value.h (+3) 
- (modified) mlir/include/mlir/IR/ValueRange.h (+5-4) 
- (modified) mlir/include/mlir/Interfaces/ControlFlowInterfaces.h (+1-1) 
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+3-7) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+1-1) 
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+6-6) 
- (modified) mlir/lib/IR/OperationSupport.cpp (+10) 
- (modified) mlir/lib/Transforms/Utils/CFGToSCF.cpp (+2-1) 
- (modified) mlir/test/lib/Dialect/Test/TestDialect.cpp (+1-1) 
- (modified) mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp (+33-22) 


``````````diff
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index 51d4e366e4970d5..4e550fe3e3a60e6 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -268,6 +268,9 @@ class OpOperand : public IROperand<OpOperand, Value> {
   /// Return which operand this is in the OpOperand list of the Operation.
   unsigned getOperandNumber();
 
+  /// Set the current value being used by this operand.
+  void assign(Value value) { set(value); }
+
 private:
   /// Keep the constructor private and accessible to the OperandStorage class
   /// only to avoid hard-to-debug typo/programming mistakes.
diff --git a/mlir/include/mlir/IR/ValueRange.h b/mlir/include/mlir/IR/ValueRange.h
index 187185b47b66695..4546f0fe4bf48c5 100644
--- a/mlir/include/mlir/IR/ValueRange.h
+++ b/mlir/include/mlir/IR/ValueRange.h
@@ -126,6 +126,9 @@ class MutableOperandRange {
                       ArrayRef<OperandSegment> operandSegments = std::nullopt);
   MutableOperandRange(Operation *owner);
 
+  /// Construct a new mutable range for the given OpOperand.
+  MutableOperandRange(OpOperand &opOperand);
+
   /// Slice this range into a sub range, with the additional operand segment.
   MutableOperandRange
   slice(unsigned subStart, unsigned subLen,
@@ -162,10 +165,8 @@ class MutableOperandRange {
   /// elements attribute, which contains the sizes of the sub ranges.
   MutableOperandRangeRange split(NamedAttribute segmentSizes) const;
 
-  /// Returns the value at the given index.
-  Value operator[](unsigned index) const {
-    return operator OperandRange()[index];
-  }
+  /// Returns the OpOperand at the given index.
+  OpOperand &operator[](unsigned index) const;
 
   OperandRange::iterator begin() const {
     return static_cast<OperandRange>(*this).begin();
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index 006aedced839f99..7f6967f11444f31 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -76,7 +76,7 @@ class SuccessorOperands {
   Value operator[](unsigned index) const {
     if (isOperandProduced(index))
       return Value();
-    return forwardedOperands[index - producedOperandCount];
+    return forwardedOperands[index - producedOperandCount].get();
   }
 
   /// Get the range of operands that are simply forwarded to the successor.
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index e5016c956804688..59ec8ccc0806f6c 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -549,22 +549,18 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
 
 bool MaterializeInDestinationOp::bufferizesToMemoryRead(
     OpOperand &opOperand, const AnalysisState &state) {
-  if (&opOperand == &getOperation()->getOpOperand(0) /*source*/)
-    return true;
-  return false;
+  return &opOperand == &getSourceMutable();
 }
 
 bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
     OpOperand &opOperand, const AnalysisState &state) {
-  if (&opOperand == &getOperation()->getOpOperand(1) /*dest*/)
-    return true;
-  return false;
+  return &opOperand == &getDestMutable();
 }
 
 AliasingValueList
 MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
                                               const AnalysisState &state) {
-  if (&opOperand == &getOperation()->getOpOperand(1) /*dest*/)
+  if (&opOperand == &getDestMutable())
     return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
   return {};
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 581e7b0a8ea86a7..f704a5235571183 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -949,7 +949,7 @@ struct FoldReshapeWithGenericOpByExpansion
           reshapeOp, "failed preconditions of fusion with producer generic op");
     }
 
-    if (!controlFoldingReshapes(&reshapeOp->getOpOperand(0))) {
+    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
       return rewriter.notifyMatchFailure(reshapeOp,
                                          "fusion blocked by control function");
     }
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 597676a017bf482..6931d386261967d 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -509,7 +509,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
   // 1. Get the producer of the source (potentially walking through
   // `iter_args` of nested `scf.for`)
   auto [fusableProducer, destinationIterArg] =
-      getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0),
+      getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
                                         loops);
   if (!fusableProducer)
     return std::nullopt;
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index ecca4dd3394e0ae..ec7a06fd8891710 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -644,11 +644,11 @@ struct InsertSliceOpInterface
     RankedTensorType destType = insertSliceOp.getDestType();
 
     // The source is always read.
-    if (&opOperand == &op->getOpOperand(0) /*src*/)
+    if (&opOperand == &insertSliceOp.getSourceMutable())
       return true;
 
     // For the destination, it depends...
-    assert(&opOperand == &insertSliceOp->getOpOperand(1) && "expected dest");
+    assert(&opOperand == &insertSliceOp.getDestMutable() && "expected dest");
 
     // Dest is not read if it is entirely overwritten. E.g.:
     // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
@@ -851,9 +851,8 @@ struct ReshapeOpInterface
                                                     tensor::ReshapeOp> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
-    if (&opOperand == &op->getOpOperand(1) /* shape */)
-      return true;
-    return false;
+    auto reshapeOp = cast<tensor::ReshapeOp>(op);
+    return &opOperand == &reshapeOp.getShapeMutable();
   }
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
@@ -915,7 +914,8 @@ struct ParallelInsertSliceOpInterface
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                                const AnalysisState &state) const {
-    return &opOperand == &op->getOpOperand(1) /*dest*/;
+    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
+    return &opOperand == &parallelInsertSliceOp.getDestMutable();
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 0cb6a1cd191b161..a9b55cec7659c55 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -437,6 +437,12 @@ MutableOperandRange::MutableOperandRange(
 MutableOperandRange::MutableOperandRange(Operation *owner)
     : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
 
+/// Construct a new mutable range for the given OpOperand.
+MutableOperandRange::MutableOperandRange(OpOperand &opOperand)
+    : MutableOperandRange(opOperand.getOwner(),
+                          /*start=*/opOperand.getOperandNumber(),
+                          /*length=*/1) {}
+
 /// Slice this range into a sub range, with the additional operand segment.
 MutableOperandRange
 MutableOperandRange::slice(unsigned subStart, unsigned subLen,
@@ -517,6 +523,10 @@ void MutableOperandRange::updateLength(unsigned newLength) {
   }
 }
 
+OpOperand &MutableOperandRange::operator[](unsigned index) const {
+  return owner->getOpOperand(start + index);
+}
+
 //===----------------------------------------------------------------------===//
 // MutableOperandRangeRange
 
diff --git a/mlir/lib/Transforms/Utils/CFGToSCF.cpp b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
index 84f23584e9f30e3..9aab89ed7553600 100644
--- a/mlir/lib/Transforms/Utils/CFGToSCF.cpp
+++ b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
@@ -277,7 +277,8 @@ class EdgeMultiplexer {
       if (index >= result->second &&
           index < result->second + edge.getSuccessor()->getNumArguments()) {
         // Original block arguments to the entry block.
-        newSuccOperands[index] = successorOperands[index - result->second];
+        newSuccOperands[index] =
+            successorOperands[index - result->second].get();
         continue;
       }
 
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 00c251936655d71..e3d86b4a44d0001 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -998,7 +998,7 @@ void LoopBlockOp::getSuccessorRegions(
 
 OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
   assert(point == getBody());
-  return getInitMutable();
+  return MutableOperandRange(getInitMutable());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index ad4f53c5af3cff4..df1d13d3bf5580d 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -2071,29 +2071,36 @@ void OpEmitter::genNamedOperandSetters() {
       continue;
     std::string name = op.getGetterName(operand.name);
 
-    auto *m = opClass.addMethod(operand.isVariadicOfVariadic()
-                                    ? "::mlir::MutableOperandRangeRange"
-                                    : "::mlir::MutableOperandRange",
-                                name + "Mutable");
+    StringRef returnType;
+    if (operand.isVariadicOfVariadic()) {
+      returnType = "::mlir::MutableOperandRangeRange";
+    } else if (operand.isVariableLength()) {
+      returnType = "::mlir::MutableOperandRange";
+    } else {
+      returnType = "::mlir::OpOperand &";
+    }
+    auto *m = opClass.addMethod(returnType, name + "Mutable");
     ERROR_IF_PRUNED(m, name, op);
     auto &body = m->body();
-    body << "  auto range = getODSOperandIndexAndLength(" << i << ");\n"
-         << "  auto mutableRange = "
-            "::mlir::MutableOperandRange(getOperation(), "
-            "range.first, range.second";
-    if (attrSizedOperands) {
-      if (emitHelper.hasProperties())
-        body << formatv(", ::mlir::MutableOperandRange::OperandSegment({0}u, "
-                        "{{getOperandSegmentSizesAttrName(), "
-                        "::mlir::DenseI32ArrayAttr::get(getContext(), "
-                        "getProperties().operandSegmentSizes)})",
-                        i);
-      else
-        body << formatv(
-            ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i,
-            emitHelper.getAttr(operandSegmentAttrName, /*isNamed=*/true));
+    body << "  auto range = getODSOperandIndexAndLength(" << i << ");\n";
+    if (operand.isVariadicOfVariadic() || operand.isVariableLength()) {
+      body << "  auto mutableRange = "
+              "::mlir::MutableOperandRange(getOperation(), "
+              "range.first, range.second";
+      if (attrSizedOperands) {
+        if (emitHelper.hasProperties())
+          body << formatv(", ::mlir::MutableOperandRange::OperandSegment({0}u, "
+                          "{{getOperandSegmentSizesAttrName(), "
+                          "::mlir::DenseI32ArrayAttr::get(getContext(), "
+                          "getProperties().operandSegmentSizes)})",
+                          i);
+        else
+          body << formatv(
+              ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i,
+              emitHelper.getAttr(operandSegmentAttrName, /*isNamed=*/true));
+      }
+      body << ");\n";
     }
-    body << ");\n";
 
     // If this operand is a nested variadic, we split the range into a
     // MutableOperandRangeRange that provides a range over all of the
@@ -2104,9 +2111,13 @@ void OpEmitter::genNamedOperandSetters() {
            << op.getGetterName(
                   operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
            << "AttrName()));\n";
-    } else {
-      // Otherwise, we use the full range directly.
+    } else if (operand.isVariableLength()) {
+      // Otherwise, if the operand has variable length, we use the full range
+      // directly.
       body << "  return mutableRange;\n";
+    } else {
+      // In case of a single operand, return a single OpOperand.
+      body << "  return getOperation()->getOpOperand(range.first);\n";
     }
   }
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/66519


More information about the Mlir-commits mailing list