[Mlir-commits] [mlir] [mlir][TblGen] get...Mutable returns OpOperand & for single operands (PR #66519)
Matthias Springer
llvmlistbot at llvm.org
Fri Sep 15 08:15:39 PDT 2023
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/66519
The TableGen code generator now generates C++ code that returns a single `OpOperand &` for `get...Mutable` of operands that are not variadic and not optional. `OpOperand::set`/`assign` can be used to set a value (same as `MutableOperandRange::assign`). This is safer than `MutableOperandRange` because only single values (and no longer a `ValueRange`) can be assigned.
E.g.:
```
// Before: Assign multiple values to non-variadic operand (forbidden, but
// compiles).
// After: Compilation error.
extractSliceOp.getSourceMutable().assign({v1, v2});
```
Depends on #66515. Review only the top commit.
>From 526e7257bfbd5d2bf17f80abd78c6687922a9e05 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 15 Sep 2023 17:11:25 +0200
Subject: [PATCH 1/2] [mlir][IR] MutableOperandRange: `operator[]` returns
`OpOperand &`
`operator[]` returns `OpOperand &` instead of `Value`.
* This allows users to get OpOperands by name instead of by a "magic" number. E.g., `extractSliceOp->getOpOperand(0)` can be written as `extractSliceOp.getSourceMutable()[0]`.
* `OperandRange` provides a read-only API to operands: `operator[]` returns `Value`. `MutableOperandRange` now provides a mutable API: `operator[]` returns `OpOperand &`, which can be used to set operands.
Note: The TableGen code generator could be changed to return `OpOperand &` (instead of `MutableOperandRange`) for non-variadic and non-optional arguments in a subsequent change. Then the `[0]` part in the above example would no longer be necessary.
---
mlir/include/mlir/IR/ValueRange.h | 6 ++----
mlir/include/mlir/Interfaces/ControlFlowInterfaces.h | 2 +-
.../Dialect/Bufferization/IR/BufferizationOps.cpp | 10 +++-------
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 2 +-
.../Dialect/SCF/Transforms/TileUsingInterface.cpp | 2 +-
.../Transforms/BufferizableOpInterfaceImpl.cpp | 12 ++++++------
mlir/lib/IR/OperationSupport.cpp | 4 ++++
mlir/lib/Transforms/Utils/CFGToSCF.cpp | 3 ++-
8 files changed, 20 insertions(+), 21 deletions(-)
diff --git a/mlir/include/mlir/IR/ValueRange.h b/mlir/include/mlir/IR/ValueRange.h
index 187185b47b66695..f1a1f1841f179e7 100644
--- a/mlir/include/mlir/IR/ValueRange.h
+++ b/mlir/include/mlir/IR/ValueRange.h
@@ -162,10 +162,8 @@ class MutableOperandRange {
/// elements attribute, which contains the sizes of the sub ranges.
MutableOperandRangeRange split(NamedAttribute segmentSizes) const;
- /// Returns the value at the given index.
- Value operator[](unsigned index) const {
- return operator OperandRange()[index];
- }
+ /// Returns the OpOperand at the given index.
+ OpOperand &operator[](unsigned index) const;
OperandRange::iterator begin() const {
return static_cast<OperandRange>(*this).begin();
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index 006aedced839f99..7f6967f11444f31 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -76,7 +76,7 @@ class SuccessorOperands {
Value operator[](unsigned index) const {
if (isOperandProduced(index))
return Value();
- return forwardedOperands[index - producedOperandCount];
+ return forwardedOperands[index - producedOperandCount].get();
}
/// Get the range of operands that are simply forwarded to the successor.
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index e5016c956804688..3a30f1a1405ec11 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -549,22 +549,18 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
bool MaterializeInDestinationOp::bufferizesToMemoryRead(
OpOperand &opOperand, const AnalysisState &state) {
- if (&opOperand == &getOperation()->getOpOperand(0) /*source*/)
- return true;
- return false;
+ return &opOperand == &getSourceMutable()[0];
}
bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
OpOperand &opOperand, const AnalysisState &state) {
- if (&opOperand == &getOperation()->getOpOperand(1) /*dest*/)
- return true;
- return false;
+ return &opOperand == &getDestMutable()[0];
}
AliasingValueList
MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
const AnalysisState &state) {
- if (&opOperand == &getOperation()->getOpOperand(1) /*dest*/)
+ if (&opOperand == &getDestMutable()[0])
return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
return {};
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 581e7b0a8ea86a7..6a01c24f026990f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -949,7 +949,7 @@ struct FoldReshapeWithGenericOpByExpansion
reshapeOp, "failed preconditions of fusion with producer generic op");
}
- if (!controlFoldingReshapes(&reshapeOp->getOpOperand(0))) {
+ if (!controlFoldingReshapes(&reshapeOp.getSrcMutable()[0])) {
return rewriter.notifyMatchFailure(reshapeOp,
"fusion blocked by control function");
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 597676a017bf482..1ce25565edcaf61 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -509,7 +509,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
// 1. Get the producer of the source (potentially walking through
// `iter_args` of nested `scf.for`)
auto [fusableProducer, destinationIterArg] =
- getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0),
+ getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable()[0],
loops);
if (!fusableProducer)
return std::nullopt;
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index ecca4dd3394e0ae..ef4352cf0c6592e 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -644,11 +644,11 @@ struct InsertSliceOpInterface
RankedTensorType destType = insertSliceOp.getDestType();
// The source is always read.
- if (&opOperand == &op->getOpOperand(0) /*src*/)
+ if (&opOperand == &insertSliceOp.getSourceMutable()[0])
return true;
// For the destination, it depends...
- assert(&opOperand == &insertSliceOp->getOpOperand(1) && "expected dest");
+ assert(&opOperand == &insertSliceOp.getDestMutable()[0] && "expected dest");
// Dest is not read if it is entirely overwritten. E.g.:
// tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
@@ -851,9 +851,8 @@ struct ReshapeOpInterface
tensor::ReshapeOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
- if (&opOperand == &op->getOpOperand(1) /* shape */)
- return true;
- return false;
+ auto reshapeOp = cast<tensor::ReshapeOp>(op);
+ return &opOperand == &reshapeOp.getShapeMutable()[0];
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
@@ -915,7 +914,8 @@ struct ParallelInsertSliceOpInterface
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
- return &opOperand == &op->getOpOperand(1) /*dest*/;
+ auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
+ return &opOperand == &parallelInsertSliceOp.getDestMutable()[0];
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 0cb6a1cd191b161..8b8eeabf38f476f 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -517,6 +517,10 @@ void MutableOperandRange::updateLength(unsigned newLength) {
}
}
+OpOperand &MutableOperandRange::operator[](unsigned index) const {
+ return owner->getOpOperand(start + index);
+}
+
//===----------------------------------------------------------------------===//
// MutableOperandRangeRange
diff --git a/mlir/lib/Transforms/Utils/CFGToSCF.cpp b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
index 84f23584e9f30e3..9aab89ed7553600 100644
--- a/mlir/lib/Transforms/Utils/CFGToSCF.cpp
+++ b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
@@ -277,7 +277,8 @@ class EdgeMultiplexer {
if (index >= result->second &&
index < result->second + edge.getSuccessor()->getNumArguments()) {
// Original block arguments to the entry block.
- newSuccOperands[index] = successorOperands[index - result->second];
+ newSuccOperands[index] =
+ successorOperands[index - result->second].get();
continue;
}
>From 24f34315475d845b302f858e349640e167214229 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 15 Sep 2023 17:11:42 +0200
Subject: [PATCH 2/2] [mlir][TblGen] `get...Mutable` returns `OpOperand &` for
single operands
The TableGen code generator now generates C++ code that returns a single `OpOperand &` for `get...Mutable` of operands that are not variadic and not optional. `OpOperand::set`/`assign` can be used to set a value (same as `MutableOperandRange::assign`). This is safer than `MutableOperandRange` because only single values (and no longer a `ValueRange`) can be assigned.
E.g.:
```
// Before: Assign multiple values to non-variadic operand (forbidden, but
// compiles).
// After: Compilation error.
extractSliceOp.getSourceMutable().assign({v1, v2});
```
---
mlir/include/mlir/IR/Value.h | 3 +
mlir/include/mlir/IR/ValueRange.h | 3 +
.../Bufferization/IR/BufferizationOps.cpp | 6 +-
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 2 +-
.../SCF/Transforms/TileUsingInterface.cpp | 2 +-
.../BufferizableOpInterfaceImpl.cpp | 8 +--
mlir/lib/IR/OperationSupport.cpp | 6 ++
mlir/test/lib/Dialect/Test/TestDialect.cpp | 2 +-
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 55 +++++++++++--------
9 files changed, 55 insertions(+), 32 deletions(-)
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index 51d4e366e4970d5..4e550fe3e3a60e6 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -268,6 +268,9 @@ class OpOperand : public IROperand<OpOperand, Value> {
/// Return which operand this is in the OpOperand list of the Operation.
unsigned getOperandNumber();
+ /// Set the current value being used by this operand.
+ void assign(Value value) { set(value); }
+
private:
/// Keep the constructor private and accessible to the OperandStorage class
/// only to avoid hard-to-debug typo/programming mistakes.
diff --git a/mlir/include/mlir/IR/ValueRange.h b/mlir/include/mlir/IR/ValueRange.h
index f1a1f1841f179e7..4546f0fe4bf48c5 100644
--- a/mlir/include/mlir/IR/ValueRange.h
+++ b/mlir/include/mlir/IR/ValueRange.h
@@ -126,6 +126,9 @@ class MutableOperandRange {
ArrayRef<OperandSegment> operandSegments = std::nullopt);
MutableOperandRange(Operation *owner);
+ /// Construct a new mutable range for the given OpOperand.
+ MutableOperandRange(OpOperand &opOperand);
+
/// Slice this range into a sub range, with the additional operand segment.
MutableOperandRange
slice(unsigned subStart, unsigned subLen,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 3a30f1a1405ec11..59ec8ccc0806f6c 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -549,18 +549,18 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
bool MaterializeInDestinationOp::bufferizesToMemoryRead(
OpOperand &opOperand, const AnalysisState &state) {
- return &opOperand == &getSourceMutable()[0];
+ return &opOperand == &getSourceMutable();
}
bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
OpOperand &opOperand, const AnalysisState &state) {
- return &opOperand == &getDestMutable()[0];
+ return &opOperand == &getDestMutable();
}
AliasingValueList
MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
const AnalysisState &state) {
- if (&opOperand == &getDestMutable()[0])
+ if (&opOperand == &getDestMutable())
return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
return {};
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 6a01c24f026990f..f704a5235571183 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -949,7 +949,7 @@ struct FoldReshapeWithGenericOpByExpansion
reshapeOp, "failed preconditions of fusion with producer generic op");
}
- if (!controlFoldingReshapes(&reshapeOp.getSrcMutable()[0])) {
+ if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
return rewriter.notifyMatchFailure(reshapeOp,
"fusion blocked by control function");
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1ce25565edcaf61..6931d386261967d 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -509,7 +509,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
// 1. Get the producer of the source (potentially walking through
// `iter_args` of nested `scf.for`)
auto [fusableProducer, destinationIterArg] =
- getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable()[0],
+ getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
loops);
if (!fusableProducer)
return std::nullopt;
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index ef4352cf0c6592e..ec7a06fd8891710 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -644,11 +644,11 @@ struct InsertSliceOpInterface
RankedTensorType destType = insertSliceOp.getDestType();
// The source is always read.
- if (&opOperand == &insertSliceOp.getSourceMutable()[0])
+ if (&opOperand == &insertSliceOp.getSourceMutable())
return true;
// For the destination, it depends...
- assert(&opOperand == &insertSliceOp.getDestMutable()[0] && "expected dest");
+ assert(&opOperand == &insertSliceOp.getDestMutable() && "expected dest");
// Dest is not read if it is entirely overwritten. E.g.:
// tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
@@ -852,7 +852,7 @@ struct ReshapeOpInterface
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
auto reshapeOp = cast<tensor::ReshapeOp>(op);
- return &opOperand == &reshapeOp.getShapeMutable()[0];
+ return &opOperand == &reshapeOp.getShapeMutable();
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
@@ -915,7 +915,7 @@ struct ParallelInsertSliceOpInterface
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
- return &opOperand == &parallelInsertSliceOp.getDestMutable()[0];
+ return &opOperand == &parallelInsertSliceOp.getDestMutable();
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 8b8eeabf38f476f..a9b55cec7659c55 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -437,6 +437,12 @@ MutableOperandRange::MutableOperandRange(
MutableOperandRange::MutableOperandRange(Operation *owner)
: MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
+/// Construct a new mutable range for the given OpOperand.
+MutableOperandRange::MutableOperandRange(OpOperand &opOperand)
+ : MutableOperandRange(opOperand.getOwner(),
+ /*start=*/opOperand.getOperandNumber(),
+ /*length=*/1) {}
+
/// Slice this range into a sub range, with the additional operand segment.
MutableOperandRange
MutableOperandRange::slice(unsigned subStart, unsigned subLen,
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 00c251936655d71..e3d86b4a44d0001 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -998,7 +998,7 @@ void LoopBlockOp::getSuccessorRegions(
OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
assert(point == getBody());
- return getInitMutable();
+ return MutableOperandRange(getInitMutable());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index ad4f53c5af3cff4..df1d13d3bf5580d 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -2071,29 +2071,36 @@ void OpEmitter::genNamedOperandSetters() {
continue;
std::string name = op.getGetterName(operand.name);
- auto *m = opClass.addMethod(operand.isVariadicOfVariadic()
- ? "::mlir::MutableOperandRangeRange"
- : "::mlir::MutableOperandRange",
- name + "Mutable");
+ StringRef returnType;
+ if (operand.isVariadicOfVariadic()) {
+ returnType = "::mlir::MutableOperandRangeRange";
+ } else if (operand.isVariableLength()) {
+ returnType = "::mlir::MutableOperandRange";
+ } else {
+ returnType = "::mlir::OpOperand &";
+ }
+ auto *m = opClass.addMethod(returnType, name + "Mutable");
ERROR_IF_PRUNED(m, name, op);
auto &body = m->body();
- body << " auto range = getODSOperandIndexAndLength(" << i << ");\n"
- << " auto mutableRange = "
- "::mlir::MutableOperandRange(getOperation(), "
- "range.first, range.second";
- if (attrSizedOperands) {
- if (emitHelper.hasProperties())
- body << formatv(", ::mlir::MutableOperandRange::OperandSegment({0}u, "
- "{{getOperandSegmentSizesAttrName(), "
- "::mlir::DenseI32ArrayAttr::get(getContext(), "
- "getProperties().operandSegmentSizes)})",
- i);
- else
- body << formatv(
- ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i,
- emitHelper.getAttr(operandSegmentAttrName, /*isNamed=*/true));
+ body << " auto range = getODSOperandIndexAndLength(" << i << ");\n";
+ if (operand.isVariadicOfVariadic() || operand.isVariableLength()) {
+ body << " auto mutableRange = "
+ "::mlir::MutableOperandRange(getOperation(), "
+ "range.first, range.second";
+ if (attrSizedOperands) {
+ if (emitHelper.hasProperties())
+ body << formatv(", ::mlir::MutableOperandRange::OperandSegment({0}u, "
+ "{{getOperandSegmentSizesAttrName(), "
+ "::mlir::DenseI32ArrayAttr::get(getContext(), "
+ "getProperties().operandSegmentSizes)})",
+ i);
+ else
+ body << formatv(
+ ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i,
+ emitHelper.getAttr(operandSegmentAttrName, /*isNamed=*/true));
+ }
+ body << ");\n";
}
- body << ");\n";
// If this operand is a nested variadic, we split the range into a
// MutableOperandRangeRange that provides a range over all of the
@@ -2104,9 +2111,13 @@ void OpEmitter::genNamedOperandSetters() {
<< op.getGetterName(
operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
<< "AttrName()));\n";
- } else {
- // Otherwise, we use the full range directly.
+ } else if (operand.isVariableLength()) {
+ // Otherwise, if the operand has variable length, we use the full range
+ // directly.
body << " return mutableRange;\n";
+ } else {
+ // In case of a single operand, return a single OpOperand.
+ body << " return getOperation()->getOpOperand(range.first);\n";
}
}
}
More information about the Mlir-commits
mailing list