[Mlir-commits] [mlir] b4db15a - [mlir] Rename getInputs->getDpsInputs and getOutputs->getDpsInits in DPS interface.
Alexander Belyaev
llvmlistbot at llvm.org
Fri Oct 28 06:41:28 PDT 2022
Author: Alexander Belyaev
Date: 2022-10-28T15:41:12+02:00
New Revision: b4db15a949646f45011f31c58133adab59f8ddb0
URL: https://github.com/llvm/llvm-project/commit/b4db15a949646f45011f31c58133adab59f8ddb0
DIFF: https://github.com/llvm/llvm-project/commit/b4db15a949646f45011f31c58133adab59f8ddb0.diff
LOG: [mlir] Rename getInputs->getDpsInputs and getOutputs->getDpsInits in DPS interface.
https://discourse.llvm.org/t/rfc-interface-for-destination-style-ops/64056
Differential Revision: https://reviews.llvm.org/D136943
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h
mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td
mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h
index f1b1c658e872d..4fc88eb9277cb 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h
@@ -33,14 +33,14 @@ struct DstBufferizableOpInterfaceExternalModel
const AnalysisState &state) const {
// Only outputs bufferize to a memory write.
auto dstOp = cast<DestinationStyleOpInterface>(op);
- return dstOp.isOutput(&opOperand);
+ return dstOp.isDpsInit(&opOperand);
}
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
// Output operands alias with their respective tied OpResults.
auto dstOp = cast<DestinationStyleOpInterface>(op);
- if (dstOp.isOutput(&opOperand))
+ if (dstOp.isDpsInit(&opOperand))
return {dstOp.getTiedOpResult(&opOperand)};
return {};
}
diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
index e49fe0c54add7..b34973ce8859b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
+++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
@@ -83,7 +83,7 @@ class LinalgDependenceGraph {
return llvm::None;
if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
return owner.getMatchingIndexingMap(operand);
- return owner.getMatchingIndexingMap(owner.getOutputOperand(
+ return owner.getMatchingIndexingMap(owner.getDpsInitOperand(
opView.get<Value>().cast<OpResult>().getResultNumber()));
}
// Return the operand number if the `opView` is an OpOperand *. Otherwise
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 28bf0b0d3618f..533a52f8f3271 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -289,7 +289,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
getNumIterators(getReductionIteratorTypeName(), iters) == 1;
}]>,
//===------------------------------------------------------------------===//
- // Input and Output arguments handling.
+ // Input and Init arguments handling.
//===------------------------------------------------------------------===//
InterfaceMethod<
/*desc=*/[{
@@ -317,7 +317,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- if (!$_op.isOutput(opOperand))
+ if (!$_op.isDpsInit(opOperand))
return false;
return payloadUsesValueFromOperand(opOperand);
}]
@@ -353,7 +353,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
// TODO: reevalute the need for a cast when a better mechanism exists.
return getBlock()->getArguments().take_front(
cast<DestinationStyleOpInterface>(*this->getOperation())
- .getNumInputs());
+ .getNumDpsInputs());
}]
>,
InterfaceMethod<
@@ -371,7 +371,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
// TODO: reevalute the need for a cast when a better mechanism exists.
return getBlock()->getArguments().take_back(
cast<DestinationStyleOpInterface>(*this->getOperation())
- .getNumOutputs());
+ .getNumDpsInits());
}]
>,
InterfaceMethod<
@@ -450,7 +450,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
// TODO: reevalute the need for a cast when a better mechanism exists.
return *(indexingMaps.begin() +
cast<DestinationStyleOpInterface>(*this->getOperation())
- .getNumInputs() +
+ .getNumDpsInputs() +
result.getResultNumber());
}]
>,
@@ -472,7 +472,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
int64_t resultIndex =
opOperand->getOperandNumber() -
cast<DestinationStyleOpInterface>(*this->getOperation())
- .getNumInputs();
+ .getNumDpsInputs();
assert(resultIndex >= 0 &&
resultIndex < this->getOperation()->getNumResults());
Operation *yieldOp = getBlock()->getTerminator();
@@ -780,49 +780,49 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
// TODO: reevalute the need for a cast when a better mechanism exists.
//========================================================================//
- int64_t getNumInputs() {
+ int64_t getNumDpsInputs() {
return cast<DestinationStyleOpInterface>(*this->getOperation())
- .getNumInputs();
+ .getNumDpsInputs();
}
- int64_t getNumOutputs() {
+ int64_t getNumDpsInits() {
return cast<DestinationStyleOpInterface>(*this->getOperation())
- .getNumOutputs();
+ .getNumDpsInits();
}
- OpOperandVector getInputOperands() {
+ OpOperandVector getDpsInputOperands() {
return cast<DestinationStyleOpInterface>(*this->getOperation())
- .getInputOperands();
+ .getDpsInputOperands();
}
- OpOperand *getInputOperand(int64_t i) {
+ OpOperand *getDpsInputOperand(int64_t i) {
return cast<DestinationStyleOpInterface>(*this->getOperation())
- .getInputOperand(i);
+ .getDpsInputOperand(i);
}
- void setOutputOperand(int64_t i, Value value) {
+ void setDpsInitOperand(int64_t i, Value value) {
return cast<DestinationStyleOpInterface>(*this->getOperation())
- .setOutputOperand(i, value);
+ .setDpsInitOperand(i, value);
}
- OpOperandVector getOutputOperands() {
+ OpOperandVector getDpsInitOperands() {
return cast<DestinationStyleOpInterface>(*this->getOperation())
- .getOutputOperands();
+ .getDpsInitOperands();
}
- OpOperand *getOutputOperand(int64_t i) {
+ OpOperand *getDpsInitOperand(int64_t i) {
return cast<DestinationStyleOpInterface>(*this->getOperation())
- .getOutputOperand(i);
+ .getDpsInitOperand(i);
}
- bool isInput(OpOperand *opOperand) {
+ bool isDpsInput(OpOperand *opOperand) {
return cast<DestinationStyleOpInterface>(*this->getOperation())
- .isInput(opOperand);
+ .isDpsInput(opOperand);
}
- bool isOutput(OpOperand *opOperand) {
+ bool isDpsInit(OpOperand *opOperand) {
return cast<DestinationStyleOpInterface>(*this->getOperation())
- .isOutput(opOperand);
+ .isDpsInit(opOperand);
}
bool isScalar(OpOperand *opOperand) {
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 510f8831f019a..b067a1ddd1e61 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -216,7 +216,7 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
getRegionBuilder() {
return nullptr;
}
- std::pair<int64_t, int64_t> getOutputsPositionRange() {
+ std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
int64_t getNumOperands = this->getNumOperands();
return {getNumOperands - getOutputs().size(), getNumOperands};
}
@@ -282,16 +282,16 @@ def MapOp : LinalgStructuredBase_Op<"map", [
}
// Implement functions necessary for DestinationStyleOpInterface.
- std::pair<int64_t, int64_t> getOutputsPositionRange() {
+ std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
int64_t getNumOperands = this->getNumOperands();
return {getNumOperands - 1, getNumOperands};
}
OpOperandVector getOpOperandsMatchingBBargs() {
- return getInputOperands();
+ return getDpsInputOperands();
}
bool payloadUsesValueFromOperand(OpOperand * opOperand) {
- if (isOutput(opOperand)) return false;
+ if (isDpsInit(opOperand)) return false;
return !getMatchingBlockArgument(opOperand).use_empty();
}
@@ -368,7 +368,7 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
getRegionBuilder() {
return nullptr;
}
- std::pair<int64_t, int64_t> getOutputsPositionRange() {
+ std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
return {getInits().size(), getNumOperands()};
}
}];
@@ -433,7 +433,7 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
}
// Implement functions necessary for DestinationStyleOpInterface.
- std::pair<int64_t, int64_t> getOutputsPositionRange() {
+ std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
int64_t getNumOperands = this->getNumOperands();
return {getNumOperands - 1, getNumOperands};
}
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index e2451f9429593..2cfdc6d8c6feb 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -723,7 +723,7 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
}]>];
let extraClassDeclaration = [{
- std::pair<int64_t, int64_t> getOutputsPositionRange() {
+ std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
return {1, 2}; // `dest` operand
}
}];
@@ -868,7 +868,7 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
/// and `strides` operands.
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; }
- std::pair<int64_t, int64_t> getOutputsPositionRange() {
+ std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
return {1, 2}; // `dest` operand
}
}];
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index e0ec6f4ffb080..b47c5fa32904e 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1397,7 +1397,7 @@ def Vector_TransferWriteOp :
/// ops of other dialects.
Value getValue() { return getVector(); }
- std::pair<int64_t, int64_t> getOutputsPositionRange() {
+ std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
return {1, 2}; // `source` operand
}
}];
diff --git a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td
index 0e7662ddecfe0..75f7477dca636 100644
--- a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td
+++ b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td
@@ -13,27 +13,27 @@ include "mlir/IR/OpBase.td"
def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
let description = [{
- Ops that are in destination style have designated output operands, which act
- as initial tensor values for the results of the operation or the output
+ Ops that are in destination style have designated init operands, which act
+ as initial tensor values for the results of the operation or the init
buffers to which the results of the op will be written.
- Output operands must be ranked tensors or ranked memrefs. Input operands can
- have any type. All non-output operands are inputs.
+ Init operands must be ranked tensors or ranked memrefs. Input operands can
+ have any type. All non-init operands are DPS inputs.
- It is assumed that the output operands of the op are the operands at
- position [start, end). The positions are defined by getOutputsPositionRange
- method. All non-output operands are "inputs" of the DPS op.
+ It is assumed that the init operands of the op are the operands at
+ position [start, end). The positions are defined by getDpsInitsPositionRange
+ method.
If the op has "tensor semantics", then the input operands are either scalars
- or ranked tensors. The output operands are ranked tensors and every tensor
- output is tied to a corresponding tensor OpResult in a 1-to-1 fashion.
- The i-th output tensor is tied to the i-th OpResult. The op may not have any
- additional OpResults. Output operands and their tied OpResults have the same
+ or ranked tensors. The init operands are ranked tensors and every tensor
+ init is tied to a corresponding tensor OpResult in a 1-to-1 fashion.
+ The i-th init tensor is tied to the i-th OpResult. The op may not have any
+ additional OpResults. Init operands and their tied OpResults have the same
type.
If the op has "buffer semantics", then the input operands are either ranked
memrefs or other non-tensor types, e.g. scalar types. Furthermore, the
- output operands are ranked memrefs and the op has no results.
+ init operands are ranked memrefs and the op has no results.
Destination-passing style abstraction makes certain transformations easier.
For example, tiling implementation can extract/insert slices from/into the
@@ -43,7 +43,7 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
bufferization) and can directly reuse the existing destination buffer.
Example of a destination style op: `%r = tensor.insert_slice %t into %d`,
- where `%t` is the single input and `%d` is the single output. `%d` is tied
+ where `%t` is the single input and `%d` is the single init. `%d` is tied
to `%r`.
Example of an op that is not in destination style: `%r = tensor.pad %t`.
@@ -51,7 +51,7 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
shape.
Each op that wants to implement DestinationStyleOpInterface needs to define
- the getOutputsPositionRange() method.
+ the getDpsInitsPositionRange() method.
}];
let cppNamespace = "::mlir";
@@ -59,9 +59,9 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
let methods = [
// This method has to be defined for every DPS op.
InterfaceMethod<
- /*desc=*/"Return start and end indices of the output operands range.",
+ /*desc=*/"Return start and end indices of the init operands range.",
/*retTy=*/"std::pair<int64_t, int64_t>",
- /*methodName=*/"getOutputsPositionRange",
+ /*methodName=*/"getDpsInitsPositionRange",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/""
@@ -70,27 +70,27 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
// Operands handling.
//===------------------------------------------------------------------===//
// The operand list is assumed to start with the input operands and end
- // with the output operands. Therefore, all methods to access the inputs
- // and outputs can be expressed if the number of output operands is know.
+ // with the init operands. Therefore, all methods to access the inputs
+ // and inits can be expressed if the number of init operands is known.
InterfaceMethod<
- /*desc=*/"Return the number of outputs.",
+ /*desc=*/"Return the number of inits.",
/*retTy=*/"int64_t",
- /*methodName=*/"getNumOutputs",
+ /*methodName=*/"getNumDpsInits",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- auto [start, end] = $_op.getOutputsPositionRange();
+ auto [start, end] = $_op.getDpsInitsPositionRange();
return end - start;
}]
>,
InterfaceMethod<
- /*desc=*/"Return the output operands.",
+ /*desc=*/"Return the init operands.",
/*retTy=*/"OpOperandVector",
- /*methodName=*/"getOutputOperands",
+ /*methodName=*/"getDpsInitOperands",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- auto [start, end] = $_op.getOutputsPositionRange();
+ auto [start, end] = $_op.getDpsInitsPositionRange();
OpOperandVector result;
result.reserve(end - start);
@@ -100,52 +100,52 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
}]
>,
InterfaceMethod<
- /*desc=*/"Return the `i`-th output operand.",
+ /*desc=*/"Return the `i`-th init operand.",
/*retTy=*/"OpOperand *",
- /*methodName=*/"getOutputOperand",
+ /*methodName=*/"getDpsInitOperand",
/*args=*/(ins "int64_t":$i),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- assert(i >= 0 && i < $_op.getNumOutputs());
- auto [start, end] = $_op.getOutputsPositionRange();
+ assert(i >= 0 && i < $_op.getNumDpsInits());
+ auto [start, end] = $_op.getDpsInitsPositionRange();
return &$_op->getOpOperand(start + i);
}]
>,
InterfaceMethod<
- /*desc=*/"Set the `i`-th output operand.",
+ /*desc=*/"Set the `i`-th init operand.",
/*retTy=*/"void",
- /*methodName=*/"setOutputOperand",
+ /*methodName=*/"setDpsInitOperand",
/*args=*/(ins "int64_t":$i, "Value":$value),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- assert(i >= 0 && i < $_op.getNumOutputs());
- auto [start, end] = $_op.getOutputsPositionRange();
+ assert(i >= 0 && i < $_op.getNumDpsInits());
+ auto [start, end] = $_op.getDpsInitsPositionRange();
$_op->setOperand(start + i, value);
}]
>,
InterfaceMethod<
/*desc=*/"Return the number of inputs.",
/*retTy=*/"int64_t",
- /*methodName=*/"getNumInputs",
+ /*methodName=*/"getNumDpsInputs",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return $_op.getNumOperands() - $_op.getNumOutputs();
+ return $_op.getNumOperands() - $_op.getNumDpsInits();
}]
>,
InterfaceMethod<
/*desc=*/"Return the input operands.",
/*retTy=*/"OpOperandVector",
- /*methodName=*/"getInputOperands",
+ /*methodName=*/"getDpsInputOperands",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- auto [start, end] = $_op.getOutputsPositionRange();
- int64_t numOutputs = end - start;
+ auto [start, end] = $_op.getDpsInitsPositionRange();
+ int64_t numInits = end - start;
int64_t numOperands = $_op.getNumOperands();
OpOperandVector result;
- result.reserve(numOperands - numOutputs);
+ result.reserve(numOperands - numInits);
for (int i = 0; i < start; ++i)
result.push_back(&$_op->getOpOperand(i));
for (int i = end; i < numOperands; ++i)
@@ -157,38 +157,38 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
InterfaceMethod<
/*desc=*/[{ Return the `i`-th input operand. }],
/*retTy=*/"OpOperand *",
- /*methodName=*/"getInputOperand",
+ /*methodName=*/"getDpsInputOperand",
/*args=*/(ins "int64_t":$i),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- assert(i >= 0 && i < getNumInputs());
- auto [start, end] = $_op.getOutputsPositionRange();
+ assert(i >= 0 && i < getNumDpsInputs());
+ auto [start, end] = $_op.getDpsInitsPositionRange();
return &$_op->getOpOperand(i < start ? i : i + end - start) ;
}]
>,
//===------------------------------------------------------------------===//
- // Input and Output arguments handling.
+ // Input and Init arguments handling.
//===------------------------------------------------------------------===//
InterfaceMethod<
/*desc=*/"Return true if `opOperand` is an input.",
/*retTy=*/"bool",
- /*methodName=*/"isInput",
+ /*methodName=*/"isDpsInput",
/*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- auto [start, end] = $_op.getOutputsPositionRange();
+ auto [start, end] = $_op.getDpsInitsPositionRange();
auto operandNumber = opOperand->getOperandNumber();
return operandNumber < start || operandNumber >= end;
}]
>,
InterfaceMethod<
- /*desc=*/"Return true if `opOperand` is an output.",
+ /*desc=*/"Return true if `opOperand` is an init.",
/*retTy=*/"bool",
- /*methodName=*/"isOutput",
+ /*methodName=*/"isDpsInit",
/*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- auto [start, end] = $_op.getOutputsPositionRange();
+ auto [start, end] = $_op.getDpsInitsPositionRange();
auto operandNumber = opOperand->getOperandNumber();
return operandNumber >= start && operandNumber < end;
}]
@@ -213,7 +213,7 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == $_op.getOperation());
- auto [start, end] = $_op.getOutputsPositionRange();
+ auto [start, end] = $_op.getDpsInitsPositionRange();
int64_t resultIndex = opOperand->getOperandNumber() - start;
assert(resultIndex >= 0 &&
resultIndex < $_op->getNumResults() );
@@ -228,14 +228,14 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opResult.getDefiningOp() == $_op.getOperation());
- return $_op.getOutputOperand(opResult.getResultNumber());
+ return $_op.getDpsInitOperand(opResult.getResultNumber());
}]
>,
//===------------------------------------------------------------------===//
// Other interface methods.
//===------------------------------------------------------------------===//
InterfaceMethod<
- /*desc=*/"Return whether the op has only ranked MemRef inputs/outputs.",
+ /*desc=*/"Return whether the op has only ranked MemRef inputs/inits.",
/*retTy=*/"bool",
/*methodName=*/"hasBufferSemantics",
/*args=*/(ins),
@@ -250,7 +250,7 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
}]
>,
InterfaceMethod<
- /*desc=*/"Return whether the op has only ranked tensor inputs/outputs.",
+ /*desc=*/"Return whether the op has only ranked tensor inputs/inits.",
/*retTy=*/"bool",
/*methodName=*/"hasTensorSemantics",
/*args=*/(ins),
@@ -270,7 +270,7 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
/*desc=*/[{
Clone the current operation with the given location and operands. This
is used to abstract away the optional underlying region creation. This
- does not change the balance between input, output_buffer and
+ does not change the balance between input, init_buffer and
init_tensors operands.
}],
/*retTy=*/"Operation *",
@@ -292,7 +292,7 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
Clone the current operation with the given location, operands
and BlockAndValueMapping but leave the regions empty. This is
used to abstract away the optional underlying region creation.
- This does not change the balance between input, output_buffer
+ This does not change the balance between input, init_buffer
and init_tensors operands.
}],
/*retTy=*/"Operation *",
diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
index 305ea9c0f32fb..866d41435d849 100644
--- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
+++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
@@ -70,7 +70,7 @@ SingleWorkgroupReduction::matchAsPerformingReduction(
return llvm::None;
// Make sure this is reduction with one input and one output.
- if (genericOp.getNumInputs() != 1 || genericOp.getNumOutputs() != 1)
+ if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
return llvm::None;
auto originalInputType = op->getOperand(0).getType().cast<MemRefType>();
diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
index c4c7efb0b7c0f..e72cf5d16659b 100644
--- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
+++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
@@ -165,7 +165,7 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
LLVM_DEBUG(dbgs() << "addDependencesBetween " << *src.getOperation()
<< " and " << *dst.getOperation() << "\n");
if (src.hasTensorSemantics() && dst.hasTensorSemantics()) {
- for (OpOperand *dstOpOperand : dst.getInputOperands()) {
+ for (OpOperand *dstOpOperand : dst.getDpsInputOperands()) {
if (!dstOpOperand->get().getType().isa<RankedTensorType>())
continue;
// Check if the operand is defined by the src.
@@ -174,7 +174,7 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
addDependenceElem(DependenceType::RAW, dstOpOperand->get(),
dstOpOperand);
}
- for (OpOperand *dstOpOperand : dst.getOutputOperands()) {
+ for (OpOperand *dstOpOperand : dst.getDpsInitOperands()) {
// Check if the operand is defined by the src.
auto definingOp = dstOpOperand->get().getDefiningOp<LinalgOp>();
if (definingOp && definingOp == src) {
@@ -190,31 +190,31 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
}
assert(src.hasBufferSemantics() && dst.hasBufferSemantics() &&
"unhandled dependence tracking for mixed buffer/tensor operations");
- for (OpOperand *srcOpOperand : src.getOutputOperands()) { // W
+ for (OpOperand *srcOpOperand : src.getDpsInitOperands()) { // W
// RAW graph
- for (OpOperand *dstOpOperand : dst.getInputOperands()) { // R
+ for (OpOperand *dstOpOperand : dst.getDpsInputOperands()) { // R
if (!dstOpOperand->get().getType().isa<MemRefType>())
continue;
if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAW alias
addDependenceElem(DependenceType::RAW, srcOpOperand, dstOpOperand);
}
// WAW graph
- for (OpOperand *dstOpOperand : dst.getOutputOperands()) // W
+ for (OpOperand *dstOpOperand : dst.getDpsInitOperands()) // W
if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAW alias
addDependenceElem(DependenceType::WAW, srcOpOperand, dstOpOperand);
}
- for (OpOperand *srcOpOperand : src.getInputOperands()) { // R
+ for (OpOperand *srcOpOperand : src.getDpsInputOperands()) { // R
if (!srcOpOperand->get().getType().isa<MemRefType>())
continue;
// RAR graph
- for (OpOperand *dstOpOperand : dst.getInputOperands()) { // R
+ for (OpOperand *dstOpOperand : dst.getDpsInputOperands()) { // R
if (!dstOpOperand->get().getType().isa<MemRefType>())
continue;
if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAR alias
addDependenceElem(DependenceType::RAR, srcOpOperand, dstOpOperand);
}
// WAR graph
- for (OpOperand *dstOpOperand : dst.getOutputOperands()) // W
+ for (OpOperand *dstOpOperand : dst.getDpsInitOperands()) // W
if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAR alias
addDependenceElem(DependenceType::WAR, srcOpOperand, dstOpOperand);
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 78e84900e1626..78a29f41e9fb8 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -119,7 +119,7 @@ static MatchContractionResult isContractionInterfaceImpl(Operation *op) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
if (!linalgOp)
return MatchContractionResult::NotLinalgOp;
- if (linalgOp.getNumInputs() != 2 || linalgOp.getNumOutputs() != 1)
+ if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
return MatchContractionResult::WrongNumOperands;
auto mapRange = linalgOp.getIndexingMapsArray();
if (linalgOp.getNumReductionLoops() == 0)
@@ -278,7 +278,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
if (!linalgOp)
return MatchConvolutionResult::NotLinalgOp;
- if (linalgOp.getNumInputs() < 2 || linalgOp.getNumOutputs() != 1)
+ if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
return MatchConvolutionResult::WrongNumOperands;
auto indexingMaps = linalgOp.getIndexingMapsArray();
@@ -436,10 +436,10 @@ static MatchFillResult isFillInterfaceImpl(Operation *op) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
if (!linalgOp)
return MatchFillResult::NotLinalgOp;
- if (linalgOp.getNumInputs() != 1 || linalgOp.getNumOutputs() != 1)
+ if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
return MatchFillResult::WrongNumOperands;
- OpOperand *value = linalgOp.getInputOperand(0);
+ OpOperand *value = linalgOp.getDpsInputOperand(0);
if (!linalgOp.isScalar(value))
return MatchFillResult::NotScalarInput;
@@ -555,9 +555,9 @@ static std::pair<int64_t, int64_t>
getResultsPositionInLoopsToShapeMap(LinalgOp &op) {
int64_t inputRankSum = 0;
int64_t outputRankSum = 0;
- for (OpOperand *input : op.getInputOperands())
+ for (OpOperand *input : op.getDpsInputOperands())
inputRankSum += op.getRank(input);
- for (OpOperand *output : op.getOutputOperands())
+ for (OpOperand *output : op.getDpsInitOperands())
outputRankSum += op.getRank(output);
return {inputRankSum, inputRankSum + outputRankSum};
}
@@ -601,7 +601,7 @@ LinalgOp::reifyResultShapes(OpBuilder &b,
createFlatListOfOperandDims(b, loc));
int64_t pos = 0;
ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
- for (OpOperand *opOperand : getOutputOperands()) {
+ for (OpOperand *opOperand : getDpsInitOperands()) {
SmallVector<Value> shapes;
for (int64_t dim : llvm::seq<int64_t>(0, getRank(opOperand))) {
if (checkDimExpr.visit(shapeExprs[pos]))
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e9f26306d58b7..00893e63bf843 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -767,8 +767,8 @@ void GenericOp::print(OpAsmPrinter &p) {
}
// Printing is shared with named ops, except for the region and attributes
- printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
- SmallVector<Value>(getOutputOperands()));
+ printCommonStructuredOpParts(p, SmallVector<Value>(getDpsInputOperands()),
+ SmallVector<Value>(getDpsInitOperands()));
genericAttrNames.push_back("operand_segment_sizes");
genericAttrNamesSet.insert(genericAttrNames.back());
@@ -858,7 +858,7 @@ void GenericOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
getGenericEffectsImpl(effects, getOperation()->getResults(),
- getInputOperands(), getOutputOperands());
+ getDpsInputOperands(), getDpsInitOperands());
}
static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) {
@@ -866,7 +866,7 @@ static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) {
return false;
// If out operand not used in payload, we can drop it.
OpOperand *outputOpOperand =
- genericOp.getOutputOperand(result.getResultNumber());
+ genericOp.getDpsInitOperand(result.getResultNumber());
if (!genericOp.payloadUsesValueFromOperand(outputOpOperand))
return true;
@@ -981,7 +981,7 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
SmallVector<AffineMap> &newIndexingMaps) const {
llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
- for (const auto &en : llvm::enumerate(genericOp.getInputOperands())) {
+ for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
OpOperand *inputOpOperand = en.value();
// Check if operand is dead and if dropping the indexing map makes the
// loops to shape computation invalid.
@@ -1029,7 +1029,7 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
// If the op doesnt have tensor semantics, keep all the outputs as
// preserved.
if (!genericOp.hasTensorSemantics()) {
- for (const auto &en : llvm::enumerate(genericOp.getOutputOperands())) {
+ for (const auto &en : llvm::enumerate(genericOp.getDpsInitOperands())) {
origToNewPos[en.index()] = newOutputOperands.size();
newOutputOperands.push_back(en.value()->get());
newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(en.value()));
@@ -1043,7 +1043,7 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
// computation.
auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
for (const auto &outputOpOperand :
- llvm::enumerate(genericOp.getOutputOperands())) {
+ llvm::enumerate(genericOp.getDpsInitOperands())) {
OpResult result = genericOp.getTiedOpResult(outputOpOperand.value());
AffineMap indexingMap =
genericOp.getMatchingIndexingMap(outputOpOperand.value());
@@ -1111,22 +1111,22 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
}
};
- OpOperandVector origInputOperands = genericOp.getInputOperands();
- OpOperandVector newInputOperands = newOp.getInputOperands();
+ OpOperandVector origInputOperands = genericOp.getDpsInputOperands();
+ OpOperandVector newInputOperands = newOp.getDpsInputOperands();
updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
- OpOperandVector origOutputOperands = genericOp.getOutputOperands();
- OpOperandVector newOutputOperands = newOp.getOutputOperands();
+ OpOperandVector origOutputOperands = genericOp.getDpsInitOperands();
+ OpOperandVector newOutputOperands = newOp.getDpsInitOperands();
updateReplacements(origOutputOperands, newOutputOperands,
origOutsToNewOutsPos);
// Drop the unused yield args.
- if (newOp.getNumOutputs() != genericOp.getNumOutputs()) {
+ if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) {
OpBuilder::InsertionGuard g(rewriter);
YieldOp origYieldOp = cast<YieldOp>(origOpBlock->getTerminator());
rewriter.setInsertionPoint(origYieldOp);
- SmallVector<Value> newYieldVals(newOp.getNumOutputs(), nullptr);
+ SmallVector<Value> newYieldVals(newOp.getNumDpsInits(), nullptr);
for (const auto &yieldOpOperands :
llvm::enumerate(origYieldOp.getValues())) {
auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
@@ -1167,9 +1167,9 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
// In the buffer case, we need to check exact buffer equality.
if (genericOp.hasBufferSemantics()) {
- if (genericOp.getNumInputs() == 1 && genericOp.getNumOutputs() == 1 &&
- genericOp.getInputOperand(0)->get() ==
- genericOp.getOutputOperand(0)->get()) {
+ if (genericOp.getNumDpsInputs() == 1 && genericOp.getNumDpsInits() == 1 &&
+ genericOp.getDpsInputOperand(0)->get() ==
+ genericOp.getDpsInitOperand(0)->get()) {
rewriter.eraseOp(genericOp);
return success();
}
@@ -1238,7 +1238,7 @@ struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
bool hasRemovedCycles = false;
// Iterate over output operands and remove any unused cycles.
for (const auto &outputOpOperand :
- llvm::enumerate(genericOp.getOutputOperands())) {
+ llvm::enumerate(genericOp.getDpsInitOperands())) {
// Check that result from out operand is dead.
Value result = genericOp.getResult(outputOpOperand.index());
@@ -1370,8 +1370,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
void MapOp::print(OpAsmPrinter &p) {
p.increaseIndent();
printCommonStructuredOpPartsWithNewLine(
- p, SmallVector<Value>(getInputOperands()),
- SmallVector<Value>(getOutputOperands()));
+ p, SmallVector<Value>(getDpsInputOperands()),
+ SmallVector<Value>(getDpsInitOperands()));
p.printOptionalAttrDict((*this)->getAttrs());
p.printNewline();
@@ -1405,8 +1405,7 @@ LogicalResult MapOp::verify() {
}
// The shape of each input must match the shape of the output.
- auto outputShape =
- getOutputOperand(0)->get().getType().cast<ShapedType>().getShape();
+ auto outputShape = getInit().getType().cast<ShapedType>().getShape();
for (Type inputArgType : TypeRange{getInputs()}) {
auto inputElemShape = inputArgType.cast<ShapedType>().getShape();
if (inputElemShape != outputShape) {
@@ -1436,7 +1435,7 @@ void MapOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
getGenericEffectsImpl(effects, getOperation()->getResults(),
- getInputOperands(), getOutputOperands());
+ getDpsInputOperands(), getDpsInitOperands());
}
//===----------------------------------------------------------------------===//
@@ -1488,12 +1487,12 @@ SmallVector<StringRef> ReduceOp::getIteratorTypesArray() {
ArrayAttr ReduceOp::getIndexingMaps() {
int64_t inputRank = getInputs()[0].getType().cast<ShapedType>().getRank();
SmallVector<AffineMap> affineMaps(
- getNumInputs(),
+ getNumDpsInputs(),
AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
AffineMap resultMap =
AffineMap::getMultiDimIdentityMap(inputRank, getContext())
.dropResults(getDimensions());
- for (int64_t i = 0, e = getNumOutputs(); i < e; ++i)
+ for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
affineMaps.push_back(resultMap);
return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
}
@@ -1502,7 +1501,7 @@ void ReduceOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
getGenericEffectsImpl(effects, getOperation()->getResults(),
- getInputOperands(), getOutputOperands());
+ getDpsInputOperands(), getDpsInitOperands());
}
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
@@ -1543,9 +1542,10 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
void ReduceOp::print(OpAsmPrinter &p) {
p.increaseIndent();
printCommonStructuredOpPartsWithNewLine(
- p, SmallVector<Value>(getInputOperands()),
- SmallVector<Value>(getOutputOperands()));
+ p, SmallVector<Value>(getDpsInputOperands()),
+ SmallVector<Value>(getDpsInitOperands()));
p.printNewline();
+
printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
@@ -1562,7 +1562,7 @@ void ReduceOp::print(OpAsmPrinter &p) {
LogicalResult ReduceOp::verify() {
ArrayRef<int64_t> dimensionsRef = getDimensions();
- for (int64_t i = 1; i < getNumInputs(); ++i) {
+ for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
if (getInputs()[i].getType().cast<ShapedType>().getShape() !=
getInputs()[0].getType().cast<ShapedType>().getShape()) {
return emitOpError() << "expects all inputs to have the same shapes. "
@@ -1571,7 +1571,7 @@ LogicalResult ReduceOp::verify() {
<< " is not equal to the shape at input-index 0.";
}
}
- for (int64_t i = 1; i < getNumOutputs(); ++i) {
+ for (int64_t i = 1; i < getNumDpsInits(); ++i) {
if (getInits()[i].getType().cast<ShapedType>().getShape() !=
getInits()[0].getType().cast<ShapedType>().getShape()) {
return emitOpError() << "expects all outputs to have the same shapes. "
@@ -1632,8 +1632,8 @@ LogicalResult ReduceOp::verify() {
// Check that the last block arguments match the element type of the outputs.
for (auto [output, bbArg] :
- llvm::zip(getOutputOperands(),
- block->getArguments().take_back(getNumOutputs()))) {
+ llvm::zip(getDpsInitOperands(),
+ block->getArguments().take_back(getNumDpsInits()))) {
auto outputElementType =
output->get().getType().cast<ShapedType>().getElementType();
if (outputElementType != bbArg.getType())
@@ -1712,9 +1712,10 @@ void TransposeOp::getAsmResultNames(
void TransposeOp::print(OpAsmPrinter &p) {
p.increaseIndent();
printCommonStructuredOpPartsWithNewLine(
- p, SmallVector<Value>(getInputOperands()),
- SmallVector<Value>(getOutputOperands()));
+ p, SmallVector<Value>(getDpsInputOperands()),
+ SmallVector<Value>(getDpsInitOperands()));
p.printNewline();
+
printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
p.decreaseIndent();
@@ -1774,7 +1775,7 @@ void TransposeOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
getGenericEffectsImpl(effects, getOperation()->getResults(),
- getInputOperands(), getOutputOperands());
+ getDpsInputOperands(), getDpsInitOperands());
}
//===----------------------------------------------------------------------===//
@@ -1802,15 +1803,15 @@ ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
// Check the operand number and types must match the element types of the
// LinalgOp interface's shaped operands.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
- if (op.getNumOperands() != linalgOp.getNumOutputs())
+ if (op.getNumOperands() != linalgOp.getNumDpsInits())
return op.emitOpError("expected number of yield values (")
- << linalgOp.getNumOutputs()
+ << linalgOp.getNumDpsInits()
<< ") to match the number of operands of the enclosing "
<< "LinalgOp (" << op.getNumOperands() << ")";
for (OpOperand &opOperand : op->getOpOperands()) {
OpOperand *outputOperand =
- linalgOp.getOutputOperand(opOperand.getOperandNumber());
+ linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
Type elementType = getElementTypeOrSelf(outputOperand->get().getType());
if (opOperand.get().getType() != elementType)
return op.emitOpError("type of yield operand ")
@@ -1981,14 +1982,14 @@ struct FoldTensorCastProducerOp : public OpInterfaceRewritePattern<LinalgOp> {
SmallVector<Value, 4> newOperands;
newOperands.reserve(op->getNumOperands());
// Inputs may fold.
- for (auto *input : op.getInputOperands()) {
+ for (auto *input : op.getDpsInputOperands()) {
auto tensorCastOp = input->get().getDefiningOp<tensor::CastOp>();
newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
? tensorCastOp.getSource()
: input->get());
}
// Init tensors may fold, in which case the resultType must also change.
- for (auto *output : op.getOutputOperands()) {
+ for (auto *output : op.getDpsInitOperands()) {
auto tensorCastOp = output->get().getDefiningOp<tensor::CastOp>();
bool fold = canFoldIntoConsumerOp(tensorCastOp);
newOperands.push_back(fold ? tensorCastOp.getOperand() : output->get());
@@ -2047,11 +2048,11 @@ struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
// for this cast, i.e. producer of the out operand, is also an operation
// that folds with tensor.cast consumer (like this pattern), the cast will
// continue to propagate as far up the stack as it can go.
- OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
+ OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
Value newOperand =
rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
- SmallVector<Value> newOperands{linalgOp.getInputOperands()};
- SmallVector<Value> outputOperands{linalgOp.getOutputOperands()};
+ SmallVector<Value> newOperands{linalgOp.getDpsInputOperands()};
+ SmallVector<Value> outputOperands{linalgOp.getDpsInitOperands()};
outputOperands[resultNumber] = newOperand;
newOperands.append(outputOperands.begin(), outputOperands.end());
@@ -2124,7 +2125,7 @@ static void createNewOperandWithStaticSizes(
return;
auto sourceType = src.getType().cast<RankedTensorType>();
Type resultType = sourceType;
- if (sourceType.hasStaticShape() && linalgOp.isOutput(opOperand)) {
+ if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
resultTypes.push_back(resultType);
return;
}
@@ -2157,7 +2158,7 @@ static void createNewOperandWithStaticSizes(
unsigned index = opOperand->getOperandNumber();
newOperands[index] = newOperand;
}
- if (linalgOp.isOutput(opOperand))
+ if (linalgOp.isDpsInit(opOperand))
resultTypes.push_back(resultType);
}
@@ -2193,7 +2194,7 @@ struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
// change in their types.
bool changeNeeded = false;
newOperands.reserve(linalgOp->getNumOperands());
- resultTypes.reserve(linalgOp.getNumOutputs());
+ resultTypes.reserve(linalgOp.getNumDpsInits());
// Iterate over all the operands and update the static sizes.
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
index 383a9267cdd0d..c75151a6cfd83 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
@@ -63,7 +63,7 @@ struct BubbleUpExtractSliceOpPattern
"expected single use of linalg op");
}
- if (linalgOp.getNumOutputs() != 1) {
+ if (linalgOp.getNumDpsInits() != 1) {
return rewriter.notifyMatchFailure(sliceOp,
"expected single output of linalg op");
}
@@ -80,7 +80,7 @@ struct BubbleUpExtractSliceOpPattern
return rewriter.notifyMatchFailure(sliceOp, "expected no rank reduction");
}
- OpOperand *outOperand = linalgOp.getOutputOperand(0);
+ OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(outOperand);
if (!indexingMap.isProjectedPermutation()) {
return rewriter.notifyMatchFailure(
@@ -119,7 +119,7 @@ struct BubbleUpExtractSliceOpPattern
/*omitPartialTileCheck=*/true);
SmallVector<Type, 4> resultTensorTypes;
- for (OpOperand *opOperand : linalgOp.getOutputOperands())
+ for (OpOperand *opOperand : linalgOp.getDpsInitOperands())
resultTensorTypes.push_back(
tiledOperands[opOperand->getOperandNumber()].getType());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index ca17a21bcc89f..f954ac08d2201 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -41,8 +41,8 @@ bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
// New input operands for the cloned op.
SmallVector<Value> newInputBuffers;
- newInputBuffers.reserve(op.getNumInputs());
- for (OpOperand *opOperand : op.getInputOperands()) {
+ newInputBuffers.reserve(op.getNumDpsInputs());
+ for (OpOperand *opOperand : op.getDpsInputOperands()) {
if (op.isScalar(opOperand)) {
newInputBuffers.push_back(opOperand->get());
continue;
@@ -56,7 +56,7 @@ bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
// New output operands for the cloned op.
SmallVector<Value> newOutputBuffers;
for (OpResult opResult : op->getOpResults()) {
- OpOperand *opOperand = op.getOutputOperand(opResult.getResultNumber());
+ OpOperand *opOperand = op.getDpsInitOperand(opResult.getResultNumber());
FailureOr<Value> resultBuffer =
getBuffer(rewriter, opOperand->get(), options);
if (failed(resultBuffer))
@@ -111,7 +111,7 @@ struct LinalgOpInterface
auto genericOp = cast<DestinationStyleOpInterface>(op);
// The i-th OpResult may alias with the i-th "out" tensor.
- return {genericOp.getOutputOperand(opResult.getResultNumber())};
+ return {genericOp.getDpsInitOperand(opResult.getResultNumber())};
}
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
@@ -119,7 +119,7 @@ struct LinalgOpInterface
auto genericOp = cast<DestinationStyleOpInterface>(op);
// The i-th "out" tensor may alias with the i-th OpResult.
- if (genericOp.isOutput(&opOperand))
+ if (genericOp.isDpsInit(&opOperand))
return {genericOp.getTiedOpResult(&opOperand)};
return {};
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
index a21e0fc769d68..d0efba5a98938 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
@@ -59,7 +59,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
return failure();
// Only support ops generating one output for now.
- if (genericOp.getNumOutputs() != 1)
+ if (genericOp.getNumDpsInits() != 1)
return failure();
auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>();
@@ -95,7 +95,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
[](AffineMap map) { return map.isPermutation(); }))
return failure();
- for (OpOperand *operand : genericOp.getOutputOperands()) {
+ for (OpOperand *operand : genericOp.getDpsInitOperands()) {
if (genericOp.payloadUsesValueFromOperand(operand))
return failure();
}
@@ -112,9 +112,9 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
return failure();
// All inputs should be constants.
- int numInputs = genericOp.getNumInputs();
+ int numInputs = genericOp.getNumDpsInputs();
SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
- for (const auto &en : llvm::enumerate(genericOp.getInputOperands())) {
+ for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
if (!matchPattern(en.value()->get(),
m_Constant(&inputValues[en.index()])))
return failure();
@@ -122,7 +122,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
// Identified this as a potential candidate for folding. Now check the
// policy to see whether we are allowed to proceed.
- for (OpOperand *operand : genericOp.getInputOperands()) {
+ for (OpOperand *operand : genericOp.getDpsInputOperands()) {
if (!controlFn(operand))
return failure();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
index 327e8bec9a7e7..bbba218c0f7b5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
@@ -176,7 +176,8 @@ DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
}
}
if (resultNumber) {
- newInitValues.push_back(genericOp.getOutputOperand(*resultNumber)->get());
+ newInitValues.push_back(
+ genericOp.getDpsInitOperand(*resultNumber)->get());
OpResult result = genericOp.getResult(*resultNumber).cast<OpResult>();
newResultTypes.push_back(result.getType());
peeledGenericOpIndexingMaps.push_back(
@@ -224,7 +225,7 @@ DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
/// Add indexing maps for the newly added operands. Use the same map
/// as those used for the new results of the peeledGenericOp.
auto indexingMaps = llvm::to_vector(
- llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *operand) {
+ llvm::map_range(genericOp.getDpsInputOperands(), [&](OpOperand *operand) {
return genericOp.getMatchingIndexingMap(operand);
}));
for (auto resultNum :
@@ -233,7 +234,7 @@ DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
indexingMaps.push_back(
peeledGenericOp.getIndexingMapMatchingResult(result));
}
- for (OpOperand *outOperand : genericOp.getOutputOperands())
+ for (OpOperand *outOperand : genericOp.getDpsInitOperands())
indexingMaps.push_back(genericOp.getMatchingIndexingMap(outOperand));
auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps);
@@ -261,7 +262,7 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
genericOp, "only operations with tensor semantics are handled");
}
- if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *outOperand) {
+ if (llvm::any_of(genericOp.getDpsInitOperands(), [&](OpOperand *outOperand) {
return !genericOp.getMatchingIndexingMap(outOperand).isPermutation();
})) {
return rewriter.notifyMatchFailure(
@@ -322,7 +323,7 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
/// In the split operations, replace block arguments uses that refer to
/// original operation to the block arguments of the newly created operation.
- unsigned origNumInputs = genericOp.getNumInputs();
+ unsigned origNumInputs = genericOp.getNumDpsInputs();
for (const auto &inputBlockArg :
llvm::enumerate(genericOp.getBody()->getArguments())) {
Value residualOpReplacementArg =
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 2fcb2acf810ed..7b9b7358bb459 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -435,7 +435,8 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
SmallVector<Type, 4> resultTypes;
resultTypes.reserve(genericOp.getNumResults());
for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
- resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
+ resultTypes.push_back(
+ newInputOutputTypes[i + genericOp.getNumDpsInputs()]);
GenericOp replacementOp = rewriter.create<GenericOp>(
loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
genericOp.getIteratorTypesArray());
@@ -447,7 +448,7 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
// the original shape.
SmallVector<Value, 4> resultReplacements;
for (const auto &result : llvm::enumerate(replacementOp.getResults())) {
- unsigned index = result.index() + replacementOp.getNumInputs();
+ unsigned index = result.index() + replacementOp.getNumDpsInputs();
auto origResultType = genericOp.getResult(result.index()).getType();
auto newResult = maybeExpand(result.value(), origResultType,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 80cef16724939..6a9c4e36a07e1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -90,7 +90,7 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
// Only allow fusing the producer of an input operand for now.
// TODO: allow fusing the producer of an output operand.
- if (!consumer.isInput(fusedOperand))
+ if (!consumer.isDpsInput(fusedOperand))
return false;
// Get the consumer index map. The number of results of the consumer index
@@ -102,7 +102,7 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
// Finally the index_map for the result must be invertible. For now just
// verify it is a permutation.
AffineMap producerResultIndexMap =
- producer.getMatchingIndexingMap(producer.getOutputOperand(0));
+ producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
if (!producerResultIndexMap.isPermutation())
return false;
@@ -128,7 +128,7 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
addToCoveredDims(operandMap);
}
- for (OpOperand *operand : producer.getInputOperands()) {
+ for (OpOperand *operand : producer.getDpsInputOperands()) {
AffineMap newIndexingMap =
getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
operand, producerResultIndexMap, consumerIndexMap);
@@ -179,7 +179,7 @@ generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp,
}
}
// TODO: allow fusing the producer of an output operand.
- assert(consumer.isInput(fusedOperand) &&
+ assert(consumer.isDpsInput(fusedOperand) &&
"expected producer of input operand");
// 3. Consumer input operands up to consumerIdx (exclusive).
for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
@@ -191,24 +191,24 @@ generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp,
// 4. Splice in producer's input operands.
for (BlockArgument bbArg :
- producerBlock.getArguments().take_front(producer.getNumInputs()))
+ producerBlock.getArguments().take_front(producer.getNumDpsInputs()))
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
// 5. Remaining consumer's input operands (drop past index `consumerIdx`).
for (BlockArgument bbArg :
consumerBlock.getArguments()
- .take_front(consumer.getNumInputs())
+ .take_front(consumer.getNumDpsInputs())
.drop_front(fusedOperand->getOperandNumber() + 1))
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
// 6. All of the producer's output operands
for (BlockArgument bbArg :
- producerBlock.getArguments().take_back(producer.getNumOutputs()))
+ producerBlock.getArguments().take_back(producer.getNumDpsInits()))
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
// 7. All of consumer's output operands.
for (BlockArgument bbArg :
- consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
+ consumerBlock.getArguments().take_back(consumer.getNumDpsInits()))
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
// 8. Clone all producer operations except for the yield and index operations
@@ -267,22 +267,24 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
auto producer = cast<GenericOp>(producerResult.getOwner());
auto consumer = cast<GenericOp>(fusedOperand->getOwner());
// TODO: allow fusing the producer of an output operand.
- assert(consumer.isInput(fusedOperand) &&
+ assert(consumer.isDpsInput(fusedOperand) &&
"expected producer of input operand");
// Compute the fused operands list and indexing maps.
SmallVector<Value> fusedInputOperands, fusedOutputOperands;
SmallVector<Type> fusedResultTypes;
SmallVector<AffineMap> fusedIndexMaps;
- fusedInputOperands.reserve(producer.getNumInputs() + consumer.getNumInputs());
- fusedOutputOperands.reserve(producer.getNumOutputs() +
- consumer.getNumOutputs());
- fusedResultTypes.reserve(producer.getNumOutputs() + consumer.getNumOutputs());
+ fusedInputOperands.reserve(producer.getNumDpsInputs() +
+ consumer.getNumDpsInputs());
+ fusedOutputOperands.reserve(producer.getNumDpsInits() +
+ consumer.getNumDpsInits());
+ fusedResultTypes.reserve(producer.getNumDpsInits() +
+ consumer.getNumDpsInits());
fusedIndexMaps.reserve(producer->getNumOperands() +
consumer->getNumOperands());
// In the following, numbering matches that of `generateFusedTensorOpRegion`.
// 3. Consumer input operands/maps up to consumerIdx (exclusive).
- auto consumerInputs = consumer.getInputOperands();
+ auto consumerInputs = consumer.getDpsInputOperands();
auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
return operand == fusedOperand;
});
@@ -294,7 +296,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
// 4. Splice in producer's input operands/maps.
AffineMap producerResultIndexMap =
producer.getIndexingMapMatchingResult(producerResult);
- for (OpOperand *opOperand : producer.getInputOperands()) {
+ for (OpOperand *opOperand : producer.getDpsInputOperands()) {
fusedInputOperands.push_back(opOperand->get());
// Compute indexing maps for the producer args in the fused operation.
AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
@@ -311,7 +313,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
}
// 6. Collect all of the producer outputs.
- for (OpOperand *opOperand : producer.getOutputOperands()) {
+ for (OpOperand *opOperand : producer.getDpsInitOperands()) {
fusedOutputOperands.push_back(opOperand->get());
AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
opOperand, producerResultIndexMap,
@@ -321,7 +323,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
}
// 7. All of consumer's output operands (skip operands: added by the builder).
- for (OpOperand *opOperand : consumer.getOutputOperands()) {
+ for (OpOperand *opOperand : consumer.getDpsInitOperands()) {
fusedOutputOperands.push_back(opOperand->get());
fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
fusedResultTypes.push_back(opOperand->get().getType());
@@ -721,8 +723,8 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
rewriter.setInsertionPoint(genericOp);
SmallVector<Value> expandedOpOperands;
- expandedOpOperands.reserve(genericOp.getNumInputs());
- for (OpOperand *opOperand : genericOp.getInputOperands()) {
+ expandedOpOperands.reserve(genericOp.getNumDpsInputs());
+ for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
if (opOperand == fusableOpOperand) {
expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
: collapsingReshapeOp.getSrc());
@@ -756,7 +758,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
Location loc = genericOp.getLoc();
SmallVector<Value> outputs;
- for (OpOperand *opOperand : genericOp.getOutputOperands()) {
+ for (OpOperand *opOperand : genericOp.getDpsInitOperands()) {
AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
RankedTensorType expandedOutputType =
@@ -805,7 +807,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(
genericOp.getMatchingIndexingMap(
- genericOp.getOutputOperand(resultNumber)),
+ genericOp.getDpsInitOperand(resultNumber)),
expansionInfo);
resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
genericOp.getLoc(), opResult.getType(),
@@ -834,7 +836,7 @@ class FoldWithProducerReshapeOpByExpansion
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
- for (OpOperand *opOperand : genericOp.getInputOperands()) {
+ for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
tensor::CollapseShapeOp reshapeOp =
opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
if (!reshapeOp)
@@ -888,7 +890,7 @@ struct FoldReshapeWithGenericOpByExpansion
if (!isFusableWithReshapeByDimExpansion(
producer,
- producer.getOutputOperand(producerResult.getResultNumber()))) {
+ producer.getDpsInitOperand(producerResult.getResultNumber()))) {
return rewriter.notifyMatchFailure(
reshapeOp, "failed preconditions of fusion with producer generic op");
}
@@ -900,7 +902,7 @@ struct FoldReshapeWithGenericOpByExpansion
Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
producer, reshapeOp,
- producer.getOutputOperand(producerResult.getResultNumber()), rewriter);
+ producer.getDpsInitOperand(producerResult.getResultNumber()), rewriter);
if (!replacementValues) {
return rewriter.notifyMatchFailure(reshapeOp,
"fusion by expansion failed");
@@ -1046,7 +1048,7 @@ static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
ArrayRef<ReassociationIndices> reassociation) {
// Some basic checks for this fusion to be valid.
- if (!genericOp.hasTensorSemantics() || genericOp.getNumOutputs() != 1)
+ if (!genericOp.hasTensorSemantics() || genericOp.getNumDpsInits() != 1)
return {};
if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
@@ -1416,8 +1418,8 @@ static FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
Location loc = genericOp->getLoc();
// Get the input operands.
- auto inputOperands = llvm::to_vector(
- llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *opOperand) {
+ auto inputOperands = llvm::to_vector(llvm::map_range(
+ genericOp.getDpsInputOperands(), [&](OpOperand *opOperand) {
return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
rewriter);
}));
@@ -1425,9 +1427,9 @@ static FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
// Get the output operands and result types.
SmallVector<Type> resultTypes;
SmallVector<Value> outputOperands;
- resultTypes.reserve(genericOp.getNumOutputs());
- outputOperands.reserve(genericOp.getNumOutputs());
- for (OpOperand *output : genericOp.getOutputOperands()) {
+ resultTypes.reserve(genericOp.getNumDpsInits());
+ outputOperands.reserve(genericOp.getNumDpsInits());
+ for (OpOperand *output : genericOp.getDpsInitOperands()) {
Value newOutput =
getCollapsedOpOperand(loc, genericOp, output, collapsingInfo, rewriter);
outputOperands.push_back(newOutput);
@@ -1575,7 +1577,7 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
PatternRewriter &rewriter) const override {
if (!genericOp.hasTensorSemantics())
return failure();
- for (OpOperand *opOperand : genericOp.getInputOperands()) {
+ for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
Operation *def = opOperand->get().getDefiningOp();
TypedAttr constantAttr;
auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
@@ -1616,9 +1618,9 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
SmallVector<Value> fusedOperands;
SmallVector<Location> fusedLocs{genericOp.getLoc()};
fusedIndexMaps.reserve(genericOp->getNumOperands());
- fusedOperands.reserve(genericOp.getNumInputs());
- fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
- for (OpOperand *inputOperand : genericOp.getInputOperands()) {
+ fusedOperands.reserve(genericOp.getNumDpsInputs());
+ fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
+ for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
if (inputOperand == opOperand)
continue;
Value inputValue = inputOperand->get();
@@ -1627,7 +1629,7 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
fusedOperands.push_back(inputValue);
fusedLocs.push_back(inputValue.getLoc());
}
- for (OpOperand *outputOperand : genericOp.getOutputOperands())
+ for (OpOperand *outputOperand : genericOp.getDpsInitOperands())
fusedIndexMaps.push_back(
genericOp.getMatchingIndexingMap(outputOperand));
@@ -1687,7 +1689,7 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
rewriter.startRootUpdate(op);
bool modifiedOutput = false;
Location loc = op.getLoc();
- for (OpOperand *opOperand : op.getOutputOperands()) {
+ for (OpOperand *opOperand : op.getDpsInitOperands()) {
if (!op.payloadUsesValueFromOperand(opOperand)) {
Value operandVal = opOperand->get();
auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
@@ -1735,7 +1737,7 @@ struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
return failure();
bool fillFound = false;
Block &payload = genericOp.getRegion().front();
- for (OpOperand *opOperand : genericOp.getInputOperands()) {
+ for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
if (!genericOp.payloadUsesValueFromOperand(opOperand))
continue;
FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp
index 866b411ba335d..473616558c9b2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp
@@ -110,7 +110,7 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
// Clone the generic op.
auto clonedOp =
cast<linalg::GenericOp>(rewriter.clone(*linalgOp.getOperation()));
- clonedOp.setOutputOperand(resultNumber, slice.getResult());
+ clonedOp.setDpsInitOperand(resultNumber, slice.getResult());
// Insert it back into the result of the fill.
rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 5c2c987710936..2d51b8d03d291 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -150,7 +150,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
// fully dynamic at construction time.
SmallVector<Type, 4> resultTypes;
resultTypes.reserve(producer->getNumResults());
- for (OpOperand *operand : producer.getOutputOperands()) {
+ for (OpOperand *operand : producer.getDpsInitOperands()) {
auto tensorType = operand->get().getType().dyn_cast<RankedTensorType>();
if (!tensorType)
continue;
@@ -211,7 +211,7 @@ static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
"expected linalg op with buffer semantics");
assert(consumer.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
- if (producer.getNumOutputs() != 1) {
+ if (producer.getNumDpsInits() != 1) {
LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)");
return false;
}
@@ -443,7 +443,7 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
b.setInsertionPoint(consumerOp);
LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n");
OpOperand *opOperand =
- producerOp.getOutputOperand(producerOpResult.getResultNumber());
+ producerOp.getDpsInitOperand(producerOpResult.getResultNumber());
LinalgOp fusedProducer =
fuse(b, producerOp, producerOp.getMatchingIndexingMap(opOperand),
consumerOpOperand);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 1dd6c35723b9a..d0cf68478beed 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -69,7 +69,7 @@ getTiledProducerLoops(OpResult producerResult,
// Get the indexing map of the `producerOp` output operand that matches
// ´producerResult´.
AffineMap producerIndexingMap = producerOp.getMatchingIndexingMap(
- producerOp.getOutputOperand(producerResult.getResultNumber()));
+ producerOp.getDpsInitOperand(producerResult.getResultNumber()));
// Keep only the tiled result slice dimensions of `producerIndexingMap`.
AffineMap tiledProducerIndexingSubMap =
@@ -173,14 +173,14 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
// output operand.
if (iterArg) {
OpOperand *outputOperand =
- producerOp.getOutputOperand(producerResult.getResultNumber());
+ producerOp.getDpsInitOperand(producerResult.getResultNumber());
iterArg->set(outputOperand->get());
tiledOperands[outputOperand->getOperandNumber()] = sliceOp.getResult();
}
// Clone the producer using the tiled producer operands.
TypeRange resultTypes = ValueRange(tiledOperands)
- .take_back(producerOp.getNumOutputs())
+ .take_back(producerOp.getNumDpsInits())
.getTypes();
LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index ea6ce398b1a77..da43b49b8a5b1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -50,8 +50,8 @@ FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
if (failed(generalizeNamedOpPrecondition(linalgOp)))
return rewriter.notifyMatchFailure(linalgOp, "preconditions not met");
- SmallVector<Value> inputs = linalgOp.getInputOperands();
- SmallVector<Value> outputs = linalgOp.getOutputOperands();
+ SmallVector<Value> inputs = linalgOp.getDpsInputOperands();
+ SmallVector<Value> outputs = linalgOp.getDpsInitOperands();
SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
SmallVector<StringRef> iterators = linalgOp.getIteratorTypesArray();
SmallVector<Type> resultTypes = linalgOp.hasTensorSemantics()
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 7515e3006b94d..baeb5c2952e6f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -111,7 +111,7 @@ struct HoistingAnalysis {
static bool isOnlyUsedAsInputOfLinalgOp(tensor::PadOp padOp) {
for (OpOperand &use : padOp.getResult().getUses()) {
auto linalgUser = dyn_cast<linalg::LinalgOp>(use.getOwner());
- if (!linalgUser || !linalgUser.isInput(&use)) {
+ if (!linalgUser || !linalgUser.isDpsInput(&use)) {
LLVM_DEBUG(DBGS() << "Found a use of " << *(padOp)
<< "\nthat is not an input tensor of a LinalgOp, "
<< "cannot hoist\n"
diff --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
index 4ea889d94e522..c9b118ba0825e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
@@ -41,9 +41,9 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
SmallVector<size_t> scalarOperands;
SmallVector<AffineMap> newIndexingMaps;
SmallVector<Value> newOperands;
- for (OpOperand *opOperand : genericOp.getInputOperands()) {
+ for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
AffineMap map = genericOp.getMatchingIndexingMap(opOperand);
- if (genericOp.isInput(opOperand) && map.isConstant()) {
+ if (genericOp.isDpsInput(opOperand) && map.isConstant()) {
scalarOperands.emplace_back(opOperand->getOperandNumber());
} else {
newIndexingMaps.emplace_back(map);
@@ -54,7 +54,7 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
if (scalarOperands.empty())
return failure();
- for (OpOperand *opOperand : genericOp.getOutputOperands())
+ for (OpOperand *opOperand : genericOp.getDpsInitOperands())
newIndexingMaps.emplace_back(genericOp.getMatchingIndexingMap(opOperand));
Location loc = genericOp->getLoc();
@@ -70,7 +70,7 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
rewriter.setInsertionPointToStart(body);
for (auto idx : llvm::reverse(scalarOperands)) {
- OpOperand *opOperand = genericOp.getInputOperand(idx);
+ OpOperand *opOperand = genericOp.getDpsInputOperand(idx);
AffineMap map = genericOp.getMatchingIndexingMap(opOperand);
SmallVector<int64_t> indices = map.getConstantResults();
SmallVector<Value> indicesValues;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 4fc914905fd7c..b8d0b09b0c90e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -138,7 +138,7 @@ static void emitScalarImplementation(OpBuilder &b, Location loc,
// TODO: Avoid the loads if the corresponding argument of the
// region has no uses.
// 1.a. Emit load from input operand or for scalars access the operand itself.
- for (OpOperand *inputOperand : linalgOp.getInputOperands()) {
+ for (OpOperand *inputOperand : linalgOp.getDpsInputOperands()) {
if (linalgOp.isScalar(inputOperand)) {
indexedValues.push_back(inputOperand->get());
continue;
@@ -149,7 +149,7 @@ static void emitScalarImplementation(OpBuilder &b, Location loc,
b.create<LoadOpTy>(loc, inputOperand->get(), indexing));
}
// 1.b. Emit load from output views.
- for (OpOperand *outputOperand : linalgOp.getOutputOperands()) {
+ for (OpOperand *outputOperand : linalgOp.getDpsInitOperands()) {
SmallVector<Value> indexing = makeCanonicalAffineApplies(
b, loc, linalgOp.getMatchingIndexingMap(outputOperand), allIvsPlusDims);
indexedValues.push_back(
@@ -161,7 +161,7 @@ static void emitScalarImplementation(OpBuilder &b, Location loc,
// 3. Emit store.
SmallVector<SmallVector<Value>, 8> indexing;
SmallVector<Value> outputBuffers;
- for (OpOperand *outputOperand : linalgOp.getOutputOperands()) {
+ for (OpOperand *outputOperand : linalgOp.getDpsInitOperands()) {
if (!outputOperand->get().getType().isa<MemRefType>())
continue;
indexing.push_back(makeCanonicalAffineApplies(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
index c2a8dbe000c2d..cabd342d86c09 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
@@ -108,9 +108,9 @@ struct SimplifyDepthwiseConvOp
LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op,
PatternRewriter &rewriter) const override {
Operation *operation = op.getOperation();
- Value input = op.getInputOperand(0)->get();
- Value kernel = op.getInputOperand(1)->get();
- Value init = op.getOutputOperand(0)->get();
+ Value input = op.getDpsInputOperand(0)->get();
+ Value kernel = op.getDpsInputOperand(1)->get();
+ Value init = op.getDpsInitOperand(0)->get();
auto stride = op.getStrides();
auto dilation = op.getDilations();
@@ -128,11 +128,11 @@ struct SimplifyDepthwiseConvQOp
LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op,
PatternRewriter &rewriter) const override {
Operation *operation = op.getOperation();
- Value input = op.getInputOperand(0)->get();
- Value kernel = op.getInputOperand(1)->get();
- Value iZp = op.getInputOperand(2)->get();
- Value kZp = op.getInputOperand(3)->get();
- Value init = op.getOutputOperand(0)->get();
+ Value input = op.getDpsInputOperand(0)->get();
+ Value kernel = op.getDpsInputOperand(1)->get();
+ Value iZp = op.getDpsInputOperand(2)->get();
+ Value kZp = op.getDpsInputOperand(3)->get();
+ Value init = op.getDpsInitOperand(0)->get();
auto stride = op.getStrides();
auto dilation = op.getDilations();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index 6f642eafda131..1d966b31c574e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -339,7 +339,7 @@ promoteSubViews(ImplicitLocOpBuilder &b, LinalgOp op,
else
opViews.push_back(
(*promotedBuffersAndViews)[operandNumber].partialLocalView);
- if (operandNumber >= op.getNumInputs())
+ if (operandNumber >= op.getNumDpsInputs())
writebackViews.emplace_back(std::make_pair(
opOperand.get(),
(*promotedBuffersAndViews)[operandNumber].partialLocalView));
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index 92d04c1cca5ba..32d05c5acbe6c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -96,7 +96,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
SmallVector<Value> newInputs;
SmallVector<AffineMap> newMaps;
// Calculate the new shapes and indexing maps of the input operands.
- for (OpOperand *operand : op.getInputOperands()) {
+ for (OpOperand *operand : op.getDpsInputOperands()) {
AffineMap map = op.getMatchingIndexingMap(operand);
SmallVector<int64_t> newShape;
SmallVector<AffineExpr> exprs;
@@ -151,8 +151,8 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
// Calculate the new output map and shape, we insert the new dimension based
// on the index returned by `controlSplitReductionFn`.
SmallVector<int64_t> newOutputShape;
- AffineMap oldOutputMap = op.getMatchingIndexingMap(op.getOutputOperand(0));
- ArrayRef<int64_t> oldShape = op.getShape(op.getOutputOperand(0));
+ AffineMap oldOutputMap = op.getMatchingIndexingMap(op.getDpsInitOperand(0));
+ ArrayRef<int64_t> oldShape = op.getShape(op.getDpsInitOperand(0));
SmallVector<AffineExpr> outputExpr;
for (unsigned idx :
llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) {
@@ -229,7 +229,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
auto reduction = b.create<GenericOp>(
loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
- SmallVector<Value>{op.getOutputOperands()}, reductionMaps,
+ SmallVector<Value>{op.getDpsInitOperands()}, reductionMaps,
reductionIteratorTypes,
[reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
Operation *clonedReductionOp = b.clone(*reductionOp);
@@ -317,7 +317,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
return b.notifyMatchFailure(op, "unknown reduction neutral");
// TODO: relax this when multi-reduction support is available.
- if (op.getNumOutputs() != static_cast<int64_t>(neutralElements.size()))
+ if (op.getNumDpsInits() != static_cast<int64_t>(neutralElements.size()))
return b.notifyMatchFailure(op, "expect one reduction per output");
// Rewrite part.
@@ -337,11 +337,11 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
// For now assume outputs are 1-1 with reduction neutralElements.
// TODO: generalize when multi-reduction support is available.
SmallVector<Value> newOutputs;
- newOutputs.reserve(op.getNumOutputs());
+ newOutputs.reserve(op.getNumDpsInits());
SmallVector<Operation *> emptyOrAllocTensorOps;
SmallVector<linalg::FillOp> fillOps;
- fillOps.reserve(op.getNumOutputs());
- for (auto it : llvm::zip(op.getOutputOperands(), neutralElements)) {
+ fillOps.reserve(op.getNumDpsInits());
+ for (auto it : llvm::zip(op.getDpsInitOperands(), neutralElements)) {
Value rankedTensor = std::get<0>(it)->get();
auto t = rankedTensor.getType().cast<RankedTensorType>();
RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
@@ -367,7 +367,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
// Reindex existing input indexings: k -> k * splitFactor + k'.
SmallVector<AffineMap> newMaps;
newMaps.reserve(op->getNumOperands() + 1);
- for (OpOperand *o : op.getInputOperands())
+ for (OpOperand *o : op.getDpsInputOperands())
newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor));
// Provision a new indexing for the shape-only tensor.
auto nDims = op.getNumLoops() + 1;
@@ -378,13 +378,13 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
// TODO: a subset of these may not reduce along reducePos and should be
// reindexed: k -> k * splitFactor + k', when multi-reduction support is
// available.
- for (OpOperand *o : op.getOutputOperands())
+ for (OpOperand *o : op.getDpsInitOperands())
newMaps.push_back(insertParallelDim(op, *o, reductionDimPos,
reductionDimSize / splitFactor));
// Step 3. Handle operands.
// Compute the new input tensors.
- SmallVector<Value> newInputs(op.getInputOperands());
+ SmallVector<Value> newInputs(op.getDpsInputOperands());
// Add a single shape-only tensor to carry the dimensions without resorting to
// more complex inversions.
newInputs.push_back(b.create<tensor::EmptyOp>(
@@ -413,7 +413,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
// TODO: all results can be handled in a single GenericOp, when
// multi-reduction support is available.
SmallVector<LinalgOp> results;
- for (auto it : llvm::zip(genericOp->getResults(), op.getOutputOperands(),
+ for (auto it : llvm::zip(genericOp->getResults(), op.getDpsInitOperands(),
combinerOps)) {
Value reindexedOutput = std::get<0>(it);
Value originalOutput = std::get<1>(it)->get();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index b5d83b864bfc8..5937da3a3200c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -324,7 +324,7 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
Operation *clonedOp = b.clone(*op.getOperation());
auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
if (destinationStyleOp) {
- for (OpOperand *outOperand : destinationStyleOp.getOutputOperands()) {
+ for (OpOperand *outOperand : destinationStyleOp.getDpsInitOperands()) {
auto *it = llvm::find(dest, outOperand->get());
assert(it != dest.end() && "dest operand not found in dest");
unsigned destNum = std::distance(dest.begin(), it);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 86dae1421387b..c843f0f400793 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -60,12 +60,12 @@ static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp,
Location loc = terminator->getLoc();
for (const auto &operand : llvm::enumerate(terminator->getOperands())) {
Value toStore = map.lookupOrDefault(operand.value());
- OpOperand *storeInto = linalgOp.getOutputOperand(operand.index());
+ OpOperand *storeInto = linalgOp.getDpsInitOperand(operand.index());
auto indices = getIndicesForAccess(
b, loc, linalgOp.getMatchingIndexingMap(storeInto), ivs);
- b.create<memref::StoreOp>(loc, toStore,
- linalgOp.getOutputOperand(operand.index())->get(),
- indices);
+ b.create<memref::StoreOp>(
+ loc, toStore, linalgOp.getDpsInitOperand(operand.index())->get(),
+ indices);
}
return success();
}
@@ -152,7 +152,7 @@ struct LinalgOpTilingInterface
return makeComposedFoldedAffineApply(b, loc, d0 - 1, ofr);
}));
- OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
+ OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
SliceParameters sliceParams = computeSliceParameters(
b, loc, outOperand->get(), sizes,
linalgOp.getMatchingIndexingMap(outOperand), offsets,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index eee454b9aec0e..415d87abc3d7b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -138,7 +138,7 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
OpOperand *currOpOperand = opOperand;
while (auto linalgOp = currOpOperand->get().getDefiningOp<LinalgOp>()) {
OpResult result = currOpOperand->get().cast<OpResult>();
- currOpOperand = linalgOp.getOutputOperand(result.getResultNumber());
+ currOpOperand = linalgOp.getDpsInitOperand(result.getResultNumber());
}
// Fail if `currOpOperand` is not defined by an ExtractSliceOp.
@@ -222,7 +222,7 @@ linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
// Clone `opToPad` to operate on the statically padded shapes.
auto resultTensorTypes =
- ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
+ ValueRange(newOperands).take_back(opToPad.getNumDpsInits()).getTypes();
paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);
// Recover the slice out of the new static results. This keeps the original
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 356ba00a98204..d565efb30241d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -150,7 +150,7 @@ mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
static Operation *matchLinalgReduction(OpOperand *outputOperand) {
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
unsigned outputPos =
- outputOperand->getOperandNumber() - linalgOp.getNumInputs();
+ outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
// Only single combiner operations are supported for now.
SmallVector<Operation *, 4> combinerOps;
if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
@@ -263,7 +263,7 @@ vectorizeLinalgYield(OpBuilder &b, Operation *op,
// TODO: use a map.
Value vectorValue = bvm.lookup(outputs.value());
Value newResult = buildVectorWrite(
- b, vectorValue, linalgOp.getOutputOperand(outputs.index()));
+ b, vectorValue, linalgOp.getDpsInitOperand(outputs.index()));
if (newResult)
newResults.push_back(newResult);
}
@@ -435,12 +435,12 @@ vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op,
SmallVector<std::pair<Value, Value>> reductionOperands;
for (Value operand : op->getOperands()) {
auto arg = operand.dyn_cast<BlockArgument>();
- if (!arg || arg.getArgNumber() < linalgOp.getNumInputs())
+ if (!arg || arg.getArgNumber() < linalgOp.getNumDpsInputs())
continue;
SmallVector<Operation *> reductionOps;
Value reduceValue = matchReduction(
linalgOp.getRegionOutputArgs(),
- arg.getArgNumber() - linalgOp.getNumInputs(), reductionOps);
+ arg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
if (!reduceValue)
continue;
reductionOperands.push_back(std::make_pair(reduceValue, operand));
@@ -517,7 +517,7 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
- if (linalgOp.getNumOutputs() == 0)
+ if (linalgOp.getNumDpsInits() == 0)
return failure();
// TODO: the common vector shape is equal to the static loop sizes only when
@@ -540,7 +540,7 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
// if (linalgOp.getShape(&opOperand).empty()) {
// readType = VectorType::get({}, bbarg.getType());
// } else {
- if (opOperand->getOperandNumber() < linalgOp.getNumInputs()) {
+ if (opOperand->getOperandNumber() < linalgOp.getNumDpsInputs()) {
map = inverseAndBroadcastProjectedPermutation(
linalgOp.getMatchingIndexingMap(opOperand));
readType = VectorType::get(commonVectorShape,
@@ -615,7 +615,7 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
LDBG("reduction precondition failed: no reduction iterator");
return failure();
}
- for (OpOperand *opOperand : op.getOutputOperands()) {
+ for (OpOperand *opOperand : op.getDpsInitOperands()) {
AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
if (indexingMap.isPermutation())
continue;
@@ -1426,11 +1426,11 @@ struct Conv1DGenerator : public StructuredGenerator<LinalgOp> {
: StructuredGenerator<LinalgOp>(builder, linalgOp), strideW(strideW),
dilationW(dilationW) {
// Determine whether `linalgOp` can be generated with this generator
- if (linalgOp.getNumInputs() != 2 || linalgOp.getNumOutputs() != 1)
+ if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
return;
- lhsShaped = linalgOp.getInputOperand(0)->get();
- rhsShaped = linalgOp.getInputOperand(1)->get();
- resShaped = linalgOp.getOutputOperand(0)->get();
+ lhsShaped = linalgOp.getDpsInputOperand(0)->get();
+ rhsShaped = linalgOp.getDpsInputOperand(1)->get();
+ resShaped = linalgOp.getDpsInitOperand(0)->get();
lhsShapedType = lhsShaped.getType().dyn_cast<ShapedType>();
rhsShapedType = rhsShaped.getType().dyn_cast<ShapedType>();
resShapedType = resShaped.getType().dyn_cast<ShapedType>();
@@ -1442,7 +1442,7 @@ struct Conv1DGenerator : public StructuredGenerator<LinalgOp> {
return;
// Check for reduction `add` preceded by `mul`.
- Operation *reduceOp = matchLinalgReduction(linalgOp.getOutputOperand(0));
+ Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
if (!reduceOp)
return;
llvm::Optional<vector::CombiningKind> maybeKind;
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index af5a2012429ba..ce15c6767b24b 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -179,7 +179,7 @@ bool isElementwise(LinalgOp op) {
return false;
// TODO: relax the restrictions on indexing map.
- for (OpOperand *opOperand : op.getOutputOperands()) {
+ for (OpOperand *opOperand : op.getDpsInitOperands()) {
if (!op.getMatchingIndexingMap(opOperand).isPermutation())
return false;
}
@@ -357,7 +357,7 @@ Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
if (!linalgOp)
break;
OpResult opResult = current.cast<OpResult>();
- current = linalgOp.getOutputOperand(opResult.getResultNumber())->get();
+ current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get();
}
auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;
@@ -479,7 +479,7 @@ void GenerateLoopNest<scf::ForOp>::doit(
"they are null entries");
SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
? SmallVector<Value>{}
- : linalgOp.getOutputOperands();
+ : linalgOp.getDpsInitOperands();
SmallVector<Value, 4> lbs, ubs, steps;
unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
@@ -490,7 +490,7 @@ void GenerateLoopNest<scf::ForOp>::doit(
"expect the number of output tensors and iter args to match");
SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
if (!iterArgs.empty()) {
- operandValuesToUse = linalgOp.getInputOperands();
+ operandValuesToUse = linalgOp.getDpsInputOperands();
operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
}
return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
@@ -520,7 +520,7 @@ void GenerateLoopNest<AffineForOp>::doit(
ArrayRef<linalg::ProcInfo> /*procInfo*/) {
SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
? SmallVector<Value>{}
- : linalgOp.getOutputOperands();
+ : linalgOp.getDpsInitOperands();
assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
SmallVector<Value, 4> lbs, ubs, steps;
unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
@@ -686,7 +686,7 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
ArrayRef<linalg::ProcInfo> procInfo) {
SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
? SmallVector<Value>{}
- : linalgOp.getOutputOperands();
+ : linalgOp.getDpsInitOperands();
assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
// This function may be passed more iterator types than ranges.
assert(iteratorTypes.size() >= loopRanges.size() &&
@@ -897,7 +897,7 @@ SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
if (op.hasBufferSemantics())
return {};
return llvm::to_vector(
- llvm::map_range(op.getOutputOperands(), [&](OpOperand *opOperand) {
+ llvm::map_range(op.getDpsInitOperands(), [&](OpOperand *opOperand) {
return operands[opOperand->getOperandNumber()].getType();
}));
}
@@ -911,7 +911,7 @@ SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
tensorResults.reserve(results.size());
// Insert a insert_slice for each output tensor.
unsigned resultIdx = 0;
- for (OpOperand *opOperand : op.getOutputOperands()) {
+ for (OpOperand *opOperand : op.getDpsInitOperands()) {
// TODO: use an interface/adaptor to avoid leaking position in
// `tiledOperands`.
Value outputTensor = operands[opOperand->getOperandNumber()];
@@ -965,7 +965,7 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
Type operandType = opOperand.get().getType();
if (!isTiled(map, tileSizes) && !(operandType.isa<RankedTensorType>() &&
- linalgOp.isOutput(&opOperand))) {
+ linalgOp.isDpsInit(&opOperand))) {
allSliceParams.push_back(llvm::None);
LLVM_DEBUG(llvm::dbgs()
<< ": not tiled: use shape: " << operandType << "\n");
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 8801fee1bdd6d..e3ab7220f748a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -394,7 +394,7 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
if (auto dstOp =
dyn_cast<DestinationStyleOpInterface>(tilingResult.tiledOp)) {
auto innerMostLoop = tilingResult.loops.back();
- SmallVector<Value> destinationTensors = dstOp.getOutputOperands();
+ SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands();
assert(destinationTensors.size() ==
innerMostLoop.getRegionIterArgs().size() &&
"unexpected number of outputs");
@@ -588,7 +588,7 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
.getDefiningOp<DestinationStyleOpInterface>()) {
scf::ForOp innerMostLoop = tileAndFuseResult.loops.back();
updateDestinationOperandsForTiledOp(
- rewriter, dstOp.getOutputOperand(resultNumber)->get(),
+ rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
}
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 48cec34ac23e8..4946256d810c7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -169,13 +169,13 @@ struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
LogicalResult matchAndRewrite(GenericOp op,
PatternRewriter &rewriter) const override {
if (!op.hasTensorSemantics() || op.getNumResults() != 1 ||
- !isAlloc(op.getOutputOperand(0), /*isZero=*/false) || !isZeroYield(op))
+ !isAlloc(op.getDpsInitOperand(0), /*isZero=*/false) || !isZeroYield(op))
return failure();
auto outputType = op.getResult(0).getType().cast<RankedTensorType>();
// Yielding zero on newly allocated (all-zero) sparse tensors can be
// optimized out directly (regardless of dynamic or static size).
if (getSparseTensorEncoding(outputType)) {
- rewriter.replaceOp(op, op.getOutputOperand(0)->get());
+ rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
return success();
}
// Incorporate zero value into allocation copy.
@@ -183,9 +183,9 @@ struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
return failure();
Value zero = constantZero(rewriter, op.getLoc(), op.getResult(0).getType());
AllocTensorOp a =
- op.getOutputOperand(0)->get().getDefiningOp<AllocTensorOp>();
+ op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(zero); });
- rewriter.replaceOp(op, op.getOutputOperand(0)->get());
+ rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
return success();
}
};
@@ -212,31 +212,31 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
LogicalResult matchAndRewrite(GenericOp op,
PatternRewriter &rewriter) const override {
// Check consumer.
- if (!op.hasTensorSemantics() || op.getNumInputs() != 2 ||
+ if (!op.hasTensorSemantics() || op.getNumDpsInputs() != 2 ||
op.getNumResults() != 1 ||
op.getNumParallelLoops() != op.getNumLoops() ||
- !op.getMatchingIndexingMap(op.getOutputOperand(0)).isIdentity() ||
- !op.getMatchingIndexingMap(op.getInputOperand(0)).isIdentity() ||
- !op.getMatchingIndexingMap(op.getInputOperand(1)).isIdentity())
+ !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
+ !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() ||
+ !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity())
return failure();
// Find consuming OP2(sparse, other) or OP2(other, sparse). The other
// operand can be sparse or dense, since the point of this rewriting rule
// is detecting a situation in which *more* sparsity is introduced into
// a computation, be it already sparse or still dense.
unsigned other = 0;
- if (isSparseTensor(op.getInputOperand(0)))
+ if (isSparseTensor(op.getDpsInputOperand(0)))
other = 1;
- else if (!isSparseTensor(op.getInputOperand(1)))
+ else if (!isSparseTensor(op.getDpsInputOperand(1)))
return failure();
// Check producer.
auto prod = dyn_cast_or_null<GenericOp>(
- op.getInputOperand(other)->get().getDefiningOp());
+ op.getDpsInputOperand(other)->get().getDefiningOp());
if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1 ||
!prod.getResult(0).hasOneUse())
return failure();
// Sampling consumer and sum of multiplication chain producer.
- if (!isAlloc(op.getOutputOperand(0), /*isZero=*/false) ||
- !isAlloc(prod.getOutputOperand(0), /*isZero=*/true) ||
+ if (!isAlloc(op.getDpsInitOperand(0), /*isZero=*/false) ||
+ !isAlloc(prod.getDpsInitOperand(0), /*isZero=*/true) ||
!isSampling(op) || !isSumOfMul(prod))
return failure();
// Modify operand structure of producer and consumer.
@@ -244,7 +244,7 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
SmallVector<Value> inputOps = prod.getInputs();
SmallVector<Value> outputOps = op.getOutputs();
SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
- inputOps.push_back(op.getInputOperand(1 - other)->get());
+ inputOps.push_back(op.getDpsInputOperand(1 - other)->get());
fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
// Fuse producer and consumer into a new generic op.
auto fusedOp = rewriter.create<GenericOp>(
@@ -277,12 +277,12 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
rewriter.create<linalg::YieldOp>(loc, last);
// Force initial value on merged allocation for dense outputs.
if (!getSparseTensorEncoding(op.getResult(0).getType())) {
- Value init = prod.getOutputOperand(0)
+ Value init = prod.getDpsInitOperand(0)
->get()
.getDefiningOp<AllocTensorOp>()
.getCopy();
AllocTensorOp a =
- op.getOutputOperand(0)->get().getDefiningOp<AllocTensorOp>();
+ op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(init); });
}
// Replace consumer with fused operation. Old producer
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index efef4ff4de6eb..f77fdb5534b81 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -311,7 +311,7 @@ static bool isAdmissibleTensorExp(Merger &merger, linalg::GenericOp op,
std::vector<unsigned> &topSort, unsigned exp,
OpOperand **sparseOut,
unsigned &outerParNest) {
- OpOperand *lhs = op.getOutputOperand(0);
+ OpOperand *lhs = op.getDpsInitOperand(0);
unsigned tensor = lhs->getOperandNumber();
auto enc = getSparseTensorEncoding(lhs->get().getType());
// An non-annotated output tensor is assumed dense, and becomes a random
@@ -410,7 +410,7 @@ static Value getCustomRedId(Operation *op) {
static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
linalg::GenericOp op) {
Location loc = op.getLoc();
- assert(op.getNumOperands() == op.getNumInputs() + 1);
+ assert(op.getNumOperands() == op.getNumDpsInputs() + 1);
codegen.loopEmitter.initializeLoopEmit(
builder, loc,
@@ -425,7 +425,7 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
Value tensor) -> Value {
// Must not be a sparse tensor.
assert(!getSparseTensorEncoding(tensor.getType()));
- OpOperand *lhs = op.getOutputOperand(0);
+ OpOperand *lhs = op.getDpsInitOperand(0);
// Two output tensors references should pointed to the same object.
assert(lhs->get() == tensor);
bool isInit = op.isInitTensor(lhs);
@@ -626,7 +626,7 @@ static void genTensorStore(Merger &merger, CodeGen &codegen, OpBuilder &builder,
return;
}
// Store during insertion.
- OpOperand *t = op.getOutputOperand(0);
+ OpOperand *t = op.getDpsInitOperand(0);
if (t == codegen.sparseOut) {
if (!rhs) {
// Only unary and binary are allowed to return uninitialized rhs
@@ -768,7 +768,7 @@ static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
// All exhausted at this level (atLevel denotes exactly at this level).
if (!atLevel)
return;
- OpOperand *lhs = op.getOutputOperand(0);
+ OpOperand *lhs = op.getDpsInitOperand(0);
if (lhs == &t) {
// Start or end a scalarized reduction
if (atStart) {
@@ -1248,7 +1248,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
/// Converts the result computed by the sparse kernel into the required form.
static void genResult(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
linalg::GenericOp op) {
- OpOperand *lhs = op.getOutputOperand(0);
+ OpOperand *lhs = op.getDpsInitOperand(0);
Type resType = lhs->get().getType();
if (getSparseTensorEncoding(resType)) {
// The sparse tensor rematerializes from the original sparse tensor's
@@ -1279,7 +1279,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
PatternRewriter &rewriter) const override {
// Detects sparse annotations and translate the per-dimension sparsity
// information for all tensors to loop indices in the kernel.
- if (op.getNumOutputs() != 1)
+ if (op.getNumDpsInits() != 1)
return failure();
unsigned numTensors = op->getNumOperands();
unsigned numLoops = op.getNumLoops();
@@ -1349,7 +1349,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
// sparse input tensor in succession until an acylic
// iteration graph results.
std::vector<unsigned> topSort;
- for (OpOperand *t : op.getInputOperands()) {
+ for (OpOperand *t : op.getDpsInputOperands()) {
unsigned tensor = t->getOperandNumber();
Value tval = t->get();
auto srcEnc = getSparseTensorEncoding(tval.getType());
diff --git a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
index d89914e6585a9..b334eedf61ef6 100644
--- a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
+++ b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
@@ -27,7 +27,7 @@ LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) {
cast<DestinationStyleOpInterface>(op);
SmallVector<OpOperand *> outputBufferOperands, outputTensorOperands;
- for (OpOperand *operand : dstStyleOp.getOutputOperands()) {
+ for (OpOperand *operand : dstStyleOp.getDpsInitOperands()) {
Type type = operand->get().getType();
if (type.isa<MemRefType>()) {
outputBufferOperands.push_back(operand);
@@ -41,11 +41,11 @@ LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) {
}
// Expect at least one output operand.
- int64_t numInputs = dstStyleOp.getNumInputs();
- int64_t numOutputs = dstStyleOp.getNumOutputs();
- if (numOutputs == 0)
+ int64_t numInputs = dstStyleOp.getNumDpsInputs();
+ int64_t numInits = dstStyleOp.getNumDpsInits();
+ if (numInits == 0)
return op->emitOpError("expected at least one output operand");
- if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs)))
+ if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numInits)))
return failure();
// Verify the number of results matches the number of output tensors.
if (op->getNumResults() != outputTensorOperands.size())
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index 5807726dd73b4..4bd6d43f0f6ef 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -26,7 +26,7 @@ static void addOperands(Operation *op, SetVector<Value> &operandSet) {
return;
TypeSwitch<Operation *, void>(op)
.Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
- SmallVector<Value> inputOperands{linalgOp.getInputOperands()};
+ SmallVector<Value> inputOperands{linalgOp.getDpsInputOperands()};
operandSet.insert(inputOperands.begin(), inputOperands.end());
})
.Default([&](Operation *operation) {
@@ -147,7 +147,7 @@ struct TestLinalgElementwiseFusion
if (expandOp->hasOneUse()) {
OpOperand &use = *expandOp->getUses().begin();
auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
- if (linalgOp && linalgOp.isOutput(&use))
+ if (linalgOp && linalgOp.isDpsInit(&use))
return true;
}
return false;
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
index 3c62496f09d24..561b5bca9f85f 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -56,7 +56,7 @@ static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) {
changed = true;
} else if (opOperand.get().getType().isa<RankedTensorType>()) {
// Tile and Fuse tensor input.
- if (opOperand.getOperandNumber() >= linalgOp.getNumInputs())
+ if (opOperand.getOperandNumber() >= linalgOp.getNumDpsInputs())
continue;
auto info = fuseProducerOfTensor(b, opOperand);
if (failed(info))
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index e4206f9c0251e..84bbe241b1a03 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2836,7 +2836,7 @@ def TestLinalgConvOp :
return "";
}
- std::pair<int64_t, int64_t> getOutputsPositionRange() {
+ std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
int64_t getNumOperands = this->getNumOperands();
return {getNumOperands - 1, getNumOperands};
}
@@ -2896,7 +2896,7 @@ def TestLinalgFillOp :
return "";
}
- std::pair<int64_t, int64_t> getOutputsPositionRange() {
+ std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
int64_t getNumOperands = this->getNumOperands();
return {getNumOperands - 1, getNumOperands};
}
diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
index 51557f519772a..de15df55c2bcf 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
@@ -236,7 +236,7 @@ structured_op: !LinalgStructuredOpConfig
scalar_arg: value
# IMPL: Test3Op::getIteratorTypesArray() {
-# IMPL-NEXT: int64_t rank = getRank(getOutputOperand(0));
+# IMPL-NEXT: int64_t rank = getRank(getDpsInitOperand(0));
# IMPL: Test3Op::getIndexingMaps() {
# IMPL-NEXT: MLIRContext *context = getContext();
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index d531d8b45160a..0a482cc28eac5 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -563,7 +563,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
return regionBuilder;
}
- std::pair<int64_t, int64_t> getOutputsPositionRange() {{
+ std::pair<int64_t, int64_t> getDpsInitsPositionRange() {{
int64_t getNumOperands = this->getNumOperands();
return {{getNumOperands - 1, getNumOperands};
}
@@ -608,7 +608,7 @@ SmallVector<StringRef> {0}::getIteratorTypesArray() {{
static const char rankPolyStructuredOpIteratorTypesFormat[] =
R"FMT(
SmallVector<StringRef> {0}::getIteratorTypesArray() {{
- int64_t rank = getRank(getOutputOperand(0));
+ int64_t rank = getRank(getDpsInitOperand(0));
return SmallVector<StringRef>(rank, getParallelIteratorTypeName());
}
)FMT";
@@ -661,7 +661,7 @@ void {0}::getEffects(SmallVectorImpl<
SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
if (hasTensorSemantics()) return;
getGenericEffectsImpl(effects,
- getOperation()->getResults(), getInputOperands(), getOutputOperands());
+ getOperation()->getResults(), getDpsInputOperands(), getDpsInitOperands());
}
)FMT";
More information about the Mlir-commits
mailing list