[Mlir-commits] [mlir] 8823e96 - [mlir][ODS] Change `get...Mutable` to return `OpOperand &` for single operands (#66519)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 3 23:35:45 PDT 2023


Author: Matthias Springer
Date: 2023-10-04T08:35:40+02:00
New Revision: 8823e961f6a41854668d2dc1a1fc787cfa85ca43

URL: https://github.com/llvm/llvm-project/commit/8823e961f6a41854668d2dc1a1fc787cfa85ca43
DIFF: https://github.com/llvm/llvm-project/commit/8823e961f6a41854668d2dc1a1fc787cfa85ca43.diff

LOG: [mlir][ODS] Change `get...Mutable` to return `OpOperand &` for single operands (#66519)

The TableGen code generator now generates C++ code in which
`get...Mutable` returns a single `OpOperand &` for operands that are
neither variadic nor optional. `OpOperand::set`/`assign` can be used to
set a value (same as `MutableOperandRange::assign`). This is safer than
returning a `MutableOperandRange` because only a single value (and no
longer a `ValueRange`) can be assigned.

E.g.:
```
// Assignment of multiple values to non-variadic operand.
// Before: Compiles, but produces invalid op.
// After: Compilation error.
extractSliceOp.getSourceMutable().assign({v1, v2});
```

Added: 
    

Modified: 
    mlir/include/mlir/IR/Value.h
    mlir/include/mlir/IR/ValueRange.h
    mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
    mlir/lib/IR/OperationSupport.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/mlir-tblgen/op-decl-and-defs.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index 51d4e366e4970d5..4e550fe3e3a60e6 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -268,6 +268,9 @@ class OpOperand : public IROperand<OpOperand, Value> {
   /// Return which operand this is in the OpOperand list of the Operation.
   unsigned getOperandNumber();
 
+  /// Set the current value being used by this operand.
+  void assign(Value value) { set(value); }
+
 private:
   /// Keep the constructor private and accessible to the OperandStorage class
   /// only to avoid hard-to-debug typo/programming mistakes.

diff  --git a/mlir/include/mlir/IR/ValueRange.h b/mlir/include/mlir/IR/ValueRange.h
index 9c11178f9cd9cae..ed69e5824f70b51 100644
--- a/mlir/include/mlir/IR/ValueRange.h
+++ b/mlir/include/mlir/IR/ValueRange.h
@@ -126,6 +126,9 @@ class MutableOperandRange {
                       ArrayRef<OperandSegment> operandSegments = std::nullopt);
   MutableOperandRange(Operation *owner);
 
+  /// Construct a new mutable range for the given OpOperand.
+  MutableOperandRange(OpOperand &opOperand);
+
   /// Slice this range into a sub range, with the additional operand segment.
   MutableOperandRange
   slice(unsigned subStart, unsigned subLen,

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index d2823d17c99c88c..01cbacc96fd42d2 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -537,18 +537,18 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
 
 bool MaterializeInDestinationOp::bufferizesToMemoryRead(
     OpOperand &opOperand, const AnalysisState &state) {
-  return &opOperand == &getSourceMutable()[0];
+  return &opOperand == &getSourceMutable();
 }
 
 bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
     OpOperand &opOperand, const AnalysisState &state) {
-  return &opOperand == &getDestMutable()[0];
+  return &opOperand == &getDestMutable();
 }
 
 AliasingValueList
 MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
                                               const AnalysisState &state) {
-  if (&opOperand == &getDestMutable()[0])
+  if (&opOperand == &getDestMutable())
     return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
   return {};
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 17346607fa9cd7f..069c613cc246d6a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -949,7 +949,7 @@ struct FoldReshapeWithGenericOpByExpansion
           reshapeOp, "failed preconditions of fusion with producer generic op");
     }
 
-    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable()[0])) {
+    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
       return rewriter.notifyMatchFailure(reshapeOp,
                                          "fusion blocked by control function");
     }

diff  --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1ecfe7906c571f9..96d6169111b3856 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -526,7 +526,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
   // 1. Get the producer of the source (potentially walking through
   // `iter_args` of nested `scf.for`)
   auto [fusableProducer, destinationInitArg] =
-      getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable()[0],
+      getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
                                         loops);
   if (!fusableProducer)
     return std::nullopt;

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index e082e2c54ef36e8..17e0cc08f8d208b 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -636,11 +636,11 @@ struct InsertSliceOpInterface
     RankedTensorType destType = insertSliceOp.getDestType();
 
     // The source is always read.
-    if (&opOperand == &insertSliceOp.getSourceMutable()[0])
+    if (&opOperand == &insertSliceOp.getSourceMutable())
       return true;
 
     // For the destination, it depends...
-    assert(&opOperand == &insertSliceOp.getDestMutable()[0] && "expected dest");
+    assert(&opOperand == &insertSliceOp.getDestMutable() && "expected dest");
 
     // Dest is not read if it is entirely overwritten. E.g.:
     // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
@@ -840,7 +840,7 @@ struct ReshapeOpInterface
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
     auto reshapeOp = cast<tensor::ReshapeOp>(op);
-    return &opOperand == &reshapeOp.getShapeMutable()[0];
+    return &opOperand == &reshapeOp.getShapeMutable();
   }
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
@@ -917,7 +917,7 @@ struct ParallelInsertSliceOpInterface
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                                const AnalysisState &state) const {
     auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
-    return &opOperand == &parallelInsertSliceOp.getDestMutable()[0];
+    return &opOperand == &parallelInsertSliceOp.getDestMutable();
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
index dff9f64169d495d..f4f46d54d78e59f 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
@@ -63,7 +63,7 @@ struct InsertSliceOpInterface
     : public SubsetInsertionOpInterface::ExternalModel<InsertSliceOpInterface,
                                                        tensor::InsertSliceOp> {
   OpOperand &getSourceOperand(Operation *op) const {
-    return op->getOpOperand(0);
+    return cast<tensor::InsertSliceOp>(op).getSourceMutable();
   }
 
   bool
@@ -91,11 +91,11 @@ struct ParallelInsertSliceOpInterface
     : public SubsetInsertionOpInterface::ExternalModel<
           ParallelInsertSliceOpInterface, tensor::ParallelInsertSliceOp> {
   OpOperand &getSourceOperand(Operation *op) const {
-    return op->getOpOperand(0);
+    return cast<tensor::ParallelInsertSliceOp>(op).getSourceMutable();
   }
 
   OpOperand &getDestinationOperand(Operation *op) const {
-    return op->getOpOperand(1);
+    return cast<tensor::ParallelInsertSliceOp>(op).getDestMutable();
   }
 
   bool

diff  --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 5aac72fcb062c39..6726b49dd3d3103 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -437,6 +437,12 @@ MutableOperandRange::MutableOperandRange(
 MutableOperandRange::MutableOperandRange(Operation *owner)
     : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
 
+/// Construct a new mutable range for the given OpOperand.
+MutableOperandRange::MutableOperandRange(OpOperand &opOperand)
+    : MutableOperandRange(opOperand.getOwner(),
+                          /*start=*/opOperand.getOperandNumber(),
+                          /*length=*/1) {}
+
 /// Slice this range into a sub range, with the additional operand segment.
 MutableOperandRange
 MutableOperandRange::slice(unsigned subStart, unsigned subLen,

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 3f1bb667a0aecb5..8aa7774d2bf008f 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -998,7 +998,7 @@ void LoopBlockOp::getSuccessorRegions(
 
 OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
   assert(point == getBody());
-  return getInitMutable();
+  return MutableOperandRange(getInitMutable());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td
index 2afde1abdb72660..85d68cc42ccbd70 100644
--- a/mlir/test/mlir-tblgen/op-decl-and-defs.td
+++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td
@@ -97,7 +97,7 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
 // CHECK:   ::mlir::Operation::operand_range getODSOperands(unsigned index);
 // CHECK:   ::mlir::TypedValue<::mlir::IntegerType> getA();
 // CHECK:   ::mlir::Operation::operand_range getB();
-// CHECK:   ::mlir::MutableOperandRange getAMutable();
+// CHECK:   ::mlir::OpOperand &getAMutable();
 // CHECK:   ::mlir::MutableOperandRange getBMutable();
 // CHECK:   ::mlir::Operation::result_range getODSResults(unsigned index);
 // CHECK:   ::mlir::TypedValue<::mlir::IntegerType> getR();

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 5f6a4e3bc52a840..7029c0eac15c392 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -2071,14 +2071,26 @@ void OpEmitter::genNamedOperandSetters() {
       continue;
     std::string name = op.getGetterName(operand.name);
 
-    auto *m = opClass.addMethod(operand.isVariadicOfVariadic()
-                                    ? "::mlir::MutableOperandRangeRange"
-                                    : "::mlir::MutableOperandRange",
-                                name + "Mutable");
+    StringRef returnType;
+    if (operand.isVariadicOfVariadic()) {
+      returnType = "::mlir::MutableOperandRangeRange";
+    } else if (operand.isVariableLength()) {
+      returnType = "::mlir::MutableOperandRange";
+    } else {
+      returnType = "::mlir::OpOperand &";
+    }
+    auto *m = opClass.addMethod(returnType, name + "Mutable");
     ERROR_IF_PRUNED(m, name, op);
     auto &body = m->body();
-    body << "  auto range = getODSOperandIndexAndLength(" << i << ");\n"
-         << "  auto mutableRange = "
+    body << "  auto range = getODSOperandIndexAndLength(" << i << ");\n";
+
+    if (!operand.isVariadicOfVariadic() && !operand.isVariableLength()) {
+      // In case of a single operand, return a single OpOperand.
+      body << "  return getOperation()->getOpOperand(range.first);\n";
+      continue;
+    }
+
+    body << "  auto mutableRange = "
             "::mlir::MutableOperandRange(getOperation(), "
             "range.first, range.second";
     if (attrSizedOperands) {


        


More information about the Mlir-commits mailing list