[Mlir-commits] [mlir] [mlir][Interfaces] Change `getDpsInitsMutable` to return `MutableArrayRef` (PR #69145)
Matthias Springer
llvmlistbot at llvm.org
Sun Oct 15 22:29:23 PDT 2023
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/69145
`getDpsInitsMutable` now returns a `MutableArrayRef`. This is so that ops can implement the `DestinationStyleOpInterface` even if they do not have any "inits". An example for such an op is `vector.transfer_read`. The current implementation returns a `MutableOperandRange` with start 0 and length 0. This is problematic because the API could be misused to append operands, which would create an invalid op. `MutableArrayRef<OpOperand>` is a better abstraction, which does not allow users to change the number of operands.
>From e82ef29f8f90718b36d940ad5cf178a4e6d594ef Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 16 Oct 2023 14:27:21 +0900
Subject: [PATCH] [mlir][Interfaces] Change `getDpsInitsMutable` to return
`MutableArrayRef`
`getDpsInitsMutable` now returns a `MutableArrayRef`. This is so that ops can implement the `DestinationStyleOpInterface` even if they do not have any "inits". An example for such an op is `vector.transfer_read`. The current implementation returns a `MutableOperandRange` with start 0 and length 0. This is problematic because the API could be misused to append operands, which would create an invalid op. `MutableArrayRef<OpOperand>` is a better abstraction, which does not allow users to change the number of operands.
---
.../Bufferization/IR/BufferizationOps.td | 5 ++---
.../mlir/Dialect/Linalg/IR/LinalgOps.td | 4 +++-
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 14 +++++++++-----
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 4 +++-
.../mlir/Dialect/Tensor/IR/TensorOps.td | 6 +++---
.../mlir/Dialect/Vector/IR/VectorOps.td | 8 +++++---
mlir/include/mlir/IR/ValueRange.h | 3 +++
.../Interfaces/DestinationStyleOpInterface.td | 18 ++++++++++++------
.../Bufferization/IR/BufferizationOps.cpp | 2 +-
mlir/lib/IR/OperationSupport.cpp | 8 ++++++--
.../Interfaces/DestinationStyleOpInterface.cpp | 17 +++++++++++++++++
mlir/test/lib/Dialect/Test/TestOps.td | 6 +++---
12 files changed, 67 insertions(+), 28 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index c779d1f843d76a0..86c013fd1323680 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -218,7 +218,8 @@ def Bufferization_MaterializeInDestinationOp
: Bufferization_Op<"materialize_in_destination",
[AllShapesMatch<["source", "dest"]>,
AllElementTypesMatch<["source", "dest"]>,
- BufferizableOpInterface, DestinationStyleOpInterface,
+ BufferizableOpInterface,
+ DeclareOpInterfaceMethods<DestinationStyleOpInterface>,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<SubsetInsertionOpInterface,
["getSourceOperand", "getValuesNeededToBuildSubsetExtraction",
@@ -301,8 +302,6 @@ def Bufferization_MaterializeInDestinationOp
return ::llvm::cast<RankedTensorType>(getResult().getType());
}
- MutableOperandRange getDpsInitsMutable();
-
bool isWritable(Value value, const AnalysisState &state);
}];
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index da12e7c83b22b89..ce2dfbd8d5f813e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -149,7 +149,9 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
int64_t getOutputOperandRank() {
return getOutputOperandType().getRank();
}
- MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
+ MutableArrayRef<OpOperand> getDpsInitsMutable() {
+ return getOutputMutable();
+ }
}];
let hasVerifier = 1;
}
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 21a5e5cc47aeb5c..74b4067855b86b7 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -208,7 +208,9 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
return nullptr;
}
- MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
+ MutableArrayRef<OpOperand> getDpsInitsMutable() {
+ return getOutputsMutable();
+ }
}];
let hasCanonicalizer = 1;
@@ -281,7 +283,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [
}
// Implement functions necessary for DestinationStyleOpInterface.
- MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
+ MutableArrayRef<OpOperand> getDpsInitsMutable() { return getInitMutable(); }
SmallVector<OpOperand *> getOpOperandsMatchingBBargs() {
return getDpsInputOperands();
@@ -377,7 +379,9 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
getRegionBuilder() {
return nullptr;
}
- MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
+ MutableArrayRef<OpOperand> getDpsInitsMutable() {
+ return getInitsMutable();
+ }
}];
let hasCustomAssemblyFormat = 1;
@@ -440,7 +444,7 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
}
// Implement functions necessary for DestinationStyleOpInterface.
- MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
+ MutableArrayRef<OpOperand> getDpsInitsMutable() { return getInitMutable(); }
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
mlir::ArrayRef<mlir::NamedAttribute>)>
@@ -508,7 +512,7 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
}
// Implement functions necessary for DestinationStyleOpInterface.
- MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
+ MutableArrayRef<OpOperand> getDpsInitsMutable() { return getInitMutable(); }
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
mlir::ArrayRef<mlir::NamedAttribute>)>
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index e1a604a88715f0e..f1c2fe969287fce 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -635,7 +635,9 @@ def ForallOp : SCF_Op<"forall", [
InParallelOp getTerminator();
// Declare the shared_outs as inits/outs to DestinationStyleOpInterface.
- MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
+ MutableArrayRef<OpOperand> getDpsInitsMutable() {
+ return getOutputsMutable();
+ }
}];
}
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 86a250b77dcc8ee..b033a242143f4d3 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -750,7 +750,7 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
}];
let extraClassDeclaration = [{
- MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
+ MutableArrayRef<OpOperand> getDpsInitsMutable() { return getDestMutable(); }
}];
let hasFolder = 1;
@@ -890,7 +890,7 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
/// and `strides` operands.
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; }
- MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
+ MutableArrayRef<OpOperand> getDpsInitsMutable() { return getDestMutable(); }
}];
let hasCanonicalizer = 1;
@@ -1710,7 +1710,7 @@ class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
RankedTensorType getDestType() {
return ::llvm::cast<RankedTensorType>(getDest().getType()); };
- MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
+ MutableArrayRef<OpOperand> getDpsInitsMutable() { return getDestMutable(); }
/// Interface method for ConditionallySpeculatable.
Speculation::Speculatability getSpeculatability();
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 2df2fe4c5ce8e9c..0e1d1af33904898 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1389,8 +1389,8 @@ def Vector_TransferReadOp :
// MaskableOpInterface methods.
bool supportsPassthru() { return true; }
- MutableOperandRange getDpsInitsMutable() {
- return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
+ MutableArrayRef<OpOperand> getDpsInitsMutable() {
+ return {};
}
}];
@@ -1553,7 +1553,9 @@ def Vector_TransferWriteOp :
/// ops of other dialects.
Value getValue() { return getVector(); }
- MutableOperandRange getDpsInitsMutable() { return getSourceMutable(); }
+ MutableArrayRef<OpOperand> getDpsInitsMutable() {
+ return getSourceMutable();
+ }
}];
let hasFolder = 1;
diff --git a/mlir/include/mlir/IR/ValueRange.h b/mlir/include/mlir/IR/ValueRange.h
index ed69e5824f70b51..51262e2d78716ec 100644
--- a/mlir/include/mlir/IR/ValueRange.h
+++ b/mlir/include/mlir/IR/ValueRange.h
@@ -158,6 +158,9 @@ class MutableOperandRange {
/// Allow implicit conversion to an OperandRange.
operator OperandRange() const;
+ /// Allow implicit conversion to a MutableArrayRef.
+ operator MutableArrayRef<OpOperand>() const;
+
/// Returns the owning operation.
Operation *getOwner() const { return owner; }
diff --git a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td
index 4c52d803e114762..673e3d6160c212e 100644
--- a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td
+++ b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td
@@ -20,9 +20,9 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
Init operands must be ranked tensors or ranked memrefs. Input operands can
have any type. All non-init operands are DPS inputs.
- The init operands of this op are specified by the MutableOperandRange that
- the `getDpsInitsMutable` interface methods returns. This implies that the
- init operands must be a consecutive range of operands.
+ The init operands of this op are specified by the OpOperands that
+ the `getDpsInitsMutable` interface method returns. The init operands must
+ be a consecutive range of operands.
If the op has "tensor semantics", then the input operands are either ranked
tensors or other non-tensor/memref types ("scalars"). The init operands are
@@ -56,8 +56,8 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
let methods = [
InterfaceMethod<
- /*desc=*/"Return start and end indices of the init operands range.",
- /*retTy=*/"::mlir::MutableOperandRange",
+ /*desc=*/"Return the consecutive range of init operands.",
+ /*retTy=*/"::llvm::MutableArrayRef<::mlir::OpOperand>",
/*methodName=*/"getDpsInitsMutable",
/*args=*/(ins)
>,
@@ -65,7 +65,13 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
let extraSharedClassDeclaration = [{
::mlir::OperandRange getDpsInits() {
- return $_op.getDpsInitsMutable();
+ auto initsMutable = $_op.getDpsInitsMutable();
+ if (initsMutable.empty())
+ return ::mlir::OperandRange($_op->operand_end(), $_op->operand_end());
+ // Inits are consecutive operands, so slice them out of the operand list.
+ unsigned firstOperandIndex = initsMutable.front().getOperandNumber();
+ return $_op->getOperands().slice(firstOperandIndex,
+ initsMutable.size());
}
/// Return the number of DPS inits.
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 5716dcc9d905016..7109c1d9f31ad5b 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -675,7 +675,7 @@ bool MaterializeInDestinationOp::isWritable(Value value,
return isa<TensorType>(getDest().getType()) ? true : getWritable();
}
-MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
+MutableArrayRef<OpOperand> MaterializeInDestinationOp::getDpsInitsMutable() {
return getDestMutable();
}
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 6726b49dd3d3103..41214b998bd6985 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -502,6 +502,10 @@ MutableOperandRange::operator OperandRange() const {
return owner->getOperands().slice(start, length);
}
+MutableOperandRange::operator MutableArrayRef<OpOperand>() const {
+ return owner->getOpOperands().slice(start, length);
+}
+
MutableOperandRangeRange
MutableOperandRange::split(NamedAttribute segmentSizes) const {
return MutableOperandRangeRange(*this, segmentSizes);
@@ -529,11 +533,11 @@ OpOperand &MutableOperandRange::operator[](unsigned index) const {
}
MutableArrayRef<OpOperand>::iterator MutableOperandRange::begin() const {
- return owner->getOpOperands().slice(start, length).begin();
+ return static_cast<MutableArrayRef<OpOperand>>(*this).begin();
}
MutableArrayRef<OpOperand>::iterator MutableOperandRange::end() const {
- return owner->getOpOperands().slice(start, length).end();
+ return static_cast<MutableArrayRef<OpOperand>>(*this).end();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
index 4e5ef66887cadf8..b8a676a9595b781 100644
--- a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
+++ b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
@@ -31,7 +31,24 @@ LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) {
cast<DestinationStyleOpInterface>(op);
SmallVector<OpOperand *> outputTensorOperands;
+#ifndef NDEBUG
+ // Operand number of the previously visited DPS init (-1 before the first).
+ // Always initialized so no path reads an indeterminate value.
+ int64_t lastOperandIdx = -1;
+ if (!dstStyleOp.getDpsInitsMutable().empty())
+ lastOperandIdx = static_cast<int64_t>(
+ dstStyleOp.getDpsInitsMutable().front().getOperandNumber()) - 1;
+#endif // NDEBUG
for (OpOperand &operand : dstStyleOp.getDpsInitsMutable()) {
+#ifndef NDEBUG
+ // DPS inits must be consecutive operands. Since `getDpsInitsMutable`
+ // returns a MutableArrayRef (that does not own the underlying data), it is
+ // currently not possible to return non-consecutive operands and this check
+ // just guards against future changes of this interface.
+ assert(lastOperandIdx + 1 == operand.getOperandNumber() &&
+ "DPS inits must be consecutive operands");
+ ++lastOperandIdx;
+#endif // NDEBUG
Type type = operand.get().getType();
if (isa<RankedTensorType>(type)) {
outputTensorOperands.push_back(&operand);
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index edb63924b3553f2..75785e42c667569 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2350,7 +2350,7 @@ def TestDestinationStyleOp :
}];
let extraClassDeclaration = [{
- mlir::MutableOperandRange getDpsInitsMutable() {
+ mlir::MutableArrayRef<mlir::OpOperand> getDpsInitsMutable() {
return getOutputsMutable();
}
}];
@@ -2411,7 +2411,7 @@ def TestLinalgConvOp :
return "";
}
- mlir::MutableOperandRange getDpsInitsMutable() {
+ mlir::MutableArrayRef<mlir::OpOperand> getDpsInitsMutable() {
return getOutputsMutable();
}
}];
@@ -2472,7 +2472,7 @@ def TestLinalgFillOp :
return "";
}
- mlir::MutableOperandRange getDpsInitsMutable() {
+ mlir::MutableArrayRef<mlir::OpOperand> getDpsInitsMutable() {
return getOutputsMutable();
}
}];
More information about the Mlir-commits
mailing list