[Mlir-commits] [mlir] a7cccb9 - [mlir] Simplify DestinationStyleOpInterface.
Alexander Belyaev
llvmlistbot at llvm.org
Mon Oct 17 03:45:02 PDT 2022
Author: Alexander Belyaev
Date: 2022-10-17T12:43:41+02:00
New Revision: a7cccb9cbb2b9954684cbea37615303a59719973
URL: https://github.com/llvm/llvm-project/commit/a7cccb9cbb2b9954684cbea37615303a59719973
DIFF: https://github.com/llvm/llvm-project/commit/a7cccb9cbb2b9954684cbea37615303a59719973.diff
LOG: [mlir] Simplify DestinationStyleOpInterface.
Differential Revision: https://reviews.llvm.org/D135348
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/CAPI/Dialect/Linalg.cpp
mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 69871faa5a2cc..995ced53d449b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -317,7 +317,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- if (!$_op.isOutputTensor(opOperand))
+ if (!$_op.isOutput(opOperand))
return false;
return payloadUsesValueFromOperand(opOperand);
}]
@@ -606,7 +606,13 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return $_op.getInputAndOutputOperands();
+ OpOperandVector result;
+ result.reserve($_op->getNumOperands());
+ llvm::transform(
+ this->getOperation()->getOpOperands(),
+ std::back_inserter(result),
+ [](OpOperand &opOperand) { return &opOperand; });
+ return result;
}]
>,
//===------------------------------------------------------------------===//
@@ -684,13 +690,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{
SmallVector<int64_t> res;
- // MLIR currently does not support dependent interfaces or interface
- // inheritance. By construction all ops with StructuredOpInterface must
- // implement DestinationStyleOpInterface.
- // TODO: reevalute the need for a cast when a better mechanism exists.
- auto iface = cast<DestinationStyleOpInterface>(*this->getOperation());
- for (OpOperand *opOperand : iface.getInputAndOutputOperands())
- llvm::append_range(res, getShape(opOperand));
+ for (OpOperand &opOperand : this->getOperation()->getOpOperands())
+ llvm::append_range(res, getShape(&opOperand));
return res;
}]
>,
@@ -779,31 +780,16 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
// TODO: reevalute the need for a cast when a better mechanism exists.
//========================================================================//
- ValueRange getInputs() {
- return cast<DestinationStyleOpInterface>(*this->getOperation())
- .getInputs();
- }
-
int64_t getNumInputs() {
return cast<DestinationStyleOpInterface>(*this->getOperation())
.getNumInputs();
}
- ValueRange getOutputs() {
- return cast<DestinationStyleOpInterface>(*this->getOperation())
- .getOutputs();
- }
-
int64_t getNumOutputs() {
return cast<DestinationStyleOpInterface>(*this->getOperation())
.getNumOutputs();
}
- int64_t getNumInputsAndOutputs() {
- return cast<DestinationStyleOpInterface>(*this->getOperation())
- .getNumInputsAndOutputs();
- }
-
OpOperandVector getInputOperands() {
return cast<DestinationStyleOpInterface>(*this->getOperation())
.getInputOperands();
@@ -814,14 +800,9 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
.getInputOperand(i);
}
- OpOperandVector getInputBufferOperands() {
- return cast<DestinationStyleOpInterface>(*this->getOperation())
- .getInputBufferOperands();
- }
-
- OpOperandVector getInputTensorOperands() {
+ void setOutputOperand(int64_t i, Value value) {
return cast<DestinationStyleOpInterface>(*this->getOperation())
- .getInputTensorOperands();
+ .setOutputOperand(i, value);
}
OpOperandVector getOutputOperands() {
@@ -834,44 +815,14 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
.getOutputOperand(i);
}
- void setOutputOperand(int64_t i, Value value) {
+ bool isInput(OpOperand *opOperand) {
return cast<DestinationStyleOpInterface>(*this->getOperation())
- .setOutputOperand(i, value);
+ .isInput(opOperand);
}
- OpOperandVector getOutputBufferOperands() {
+ bool isOutput(OpOperand *opOperand) {
return cast<DestinationStyleOpInterface>(*this->getOperation())
- .getOutputBufferOperands();
- }
-
- OpOperandVector getOutputTensorOperands() {
- return cast<DestinationStyleOpInterface>(*this->getOperation())
- .getOutputTensorOperands();
- }
-
- SmallVector<MemRefType> getOutputBufferTypes() {
- return cast<DestinationStyleOpInterface>(*this->getOperation())
- .getOutputBufferTypes();
- }
-
- SmallVector<RankedTensorType> getOutputTensorTypes() {
- return cast<DestinationStyleOpInterface>(*this->getOperation())
- .getOutputTensorTypes();
- }
-
- OpOperandVector getInputAndOutputOperands() {
- return cast<DestinationStyleOpInterface>(*this->getOperation())
- .getInputAndOutputOperands();
- }
-
- bool isInputTensor(OpOperand *opOperand) {
- return cast<DestinationStyleOpInterface>(*this->getOperation())
- .isInputTensor(opOperand);
- }
-
- bool isOutputTensor(OpOperand *opOperand) {
- return cast<DestinationStyleOpInterface>(*this->getOperation())
- .isOutputTensor(opOperand);
+ .isOutput(opOperand);
}
bool isScalar(OpOperand *opOperand) {
@@ -928,331 +879,185 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
let verifyWithRegions = 1;
}
-// The 'DestinationStyleOpInterface' provides access to the methods relevant
-// for destination-style ops. A destination-style operation has 'n' input
-// arguments and 'm' output arguments. Each op that wants to implement
-// DestinationStyleOpInterface needs to define getInputs() and getOutputs()
-// methods.
+// Ops that are in destination style have designated output operands, which act
+// as initial tensor values for the results of the operation or the output
+// buffers to which the results of the op will be written.
+//
+// Output operands must be tensors or memrefs. Input operands can have any
+// type. All non-output operands are inputs.
+
+// It is assumed that the output operands of the op are the operands at
+// positions [start, end), as returned by the getOutputsPositionRange()
+// method. All non-output operands are "inputs" of the DPS op.
+
+// If the op has "tensor semantics", then the input operands are either scalars
+// or tensors. The output operands are tensors and every tensor output is tied
+// to a corresponding tensor OpResult in a 1-to-1 fashion. The i-th output
+// tensor is tied to the i-th OpResult. The op may not have any additional
+// OpResults. Output operands and their tied OpResults have the same type.
+//
+// If the op has "buffer semantics", then the input operands are either memrefs
+// or other non-tensor types, e.g. scalar types. Furthermore, the output
+// operands are memrefs and the op has no results.
+//
+// Destination-passing style abstraction makes certain transformations easier.
+// For example, tiling implementation can extract/insert slices from/into the
+// destination of an op and use the resulting shaped value as an iter_arg in
+// the surrounding loop structure. As another example, bufferization does not
+// have to allocate new buffers for destinations (in case of in-place
+// bufferization) and can directly reuse the existing destination buffer.
+//
+// Example of a destination style op: `%r = tensor.insert_slice %t into %d`,
+// where `%t` is the single input and `%d` is the single output. `%d` is tied
+// to `%r`.
+//
+// Example of an op that is not in destination style: `%r = tensor.pad %t`.
+// This op is not in destination style because `%r` and `%t` have different
+// shape.
+//
+// Each op that wants to implement DestinationStyleOpInterface needs to define
+// the getOutputsPositionRange() method.
def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
let cppNamespace = "::mlir::linalg";
let methods = [
- //===------------------------------------------------------------------===//
- // Num input/output arguments handling.
- //===------------------------------------------------------------------===//
- // `getInputs` must be defined by each op that wants to implement the
- // DestinationStyleOpInterface.
+ // This method has to be defined for every DPS op.
InterfaceMethod<
- /*desc=*/[{
- Return the input shape operands.
- }],
- /*retTy=*/"ValueRange",
- /*methodName=*/"getInputs",
- /*args=*/(ins)
- >,
- // These special methods rely on `getInputs` and `getOutputs` being defined
- // by each op that wants to implement the DestinationStyleOpInterface.
- InterfaceMethod<
- /*desc=*/[{
- Return the number of inputs.
- }],
- /*retTy=*/"int64_t",
- /*methodName=*/"getNumInputs",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- return $_op.getInputs().size();
- }]
- >,
- // `getOutputs` must be defined by each op that wants to implement the
- // DestinationStyleOpInterface.
- InterfaceMethod<
- /*desc=*/[{
- Return the output shape operands.
- }],
- /*retTy=*/"ValueRange",
- /*methodName=*/"getOutputs",
- /*args=*/(ins)
- >,
- InterfaceMethod<
- /*desc=*/[{
- Return the number of outputs.
- }],
- /*retTy=*/"int64_t",
- /*methodName=*/"getNumOutputs",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- return $_op.getOutputs().size();
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
- Return the number of inputs and outputs.
- }],
- /*retTy=*/"int64_t",
- /*methodName=*/"getNumInputsAndOutputs",
+ /*desc=*/"Return start and end indices of the output operands range.",
+ /*retTy=*/"std::pair<int64_t, int64_t>",
+ /*methodName=*/"getOutputsPositionRange",
/*args=*/(ins),
/*methodBody=*/"",
- /*defaultImplementation=*/[{
- return this->getOperation()->getNumOperands();
- }]
+ /*defaultImplementation=*/""
>,
//===------------------------------------------------------------------===//
- // Input operands handling.
+ // Operands handling.
//===------------------------------------------------------------------===//
+ // Output operands occupy the contiguous range returned by
+ // getOutputsPositionRange(); all other operands are inputs. Therefore, all
+ // methods to access the inputs and outputs can be expressed via that range.
InterfaceMethod<
- /*desc=*/[{
- Return the input operands.
- }],
- /*retTy=*/"OpOperandVector",
- /*methodName=*/"getInputOperands",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- int64_t numInputs = getNumInputs();
- OpOperandVector result;
- result.reserve(numInputs);
- llvm::transform(
- this->getOperation()->getOpOperands().take_front(numInputs),
- std::back_inserter(result),
- [](OpOperand &opOperand) { return &opOperand; });
- return result;
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
- Return the `i`-th input operand.
- }],
- /*retTy=*/"OpOperand*",
- /*methodName=*/"getInputOperand",
- /*args=*/(ins "int64_t":$i),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- assert(i >= 0 && i < getNumInputs());
- return &this->getOperation()->getOpOperand(i);
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
- Return the subset of input operands that are of buffer type.
- }],
- /*retTy=*/"OpOperandVector",
- /*methodName=*/"getInputBufferOperands",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- OpOperandVector result;
- result.reserve(getNumInputs());
- llvm::copy_if(getInputOperands(),
- std::back_inserter(result),
- [](OpOperand *opOperand) {
- return opOperand->get().getType().template isa<MemRefType>();
- });
- return result;
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
- Return the subset of input operands that are of tensor type.
- }],
- /*retTy=*/"OpOperandVector",
- /*methodName=*/"getInputTensorOperands",
+ /*desc=*/"Return the number of outputs.",
+ /*retTy=*/"int64_t",
+ /*methodName=*/"getNumOutputs",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- OpOperandVector result;
- result.reserve(getNumInputs());
- llvm::copy_if(getInputOperands(),
- std::back_inserter(result),
- [](OpOperand *opOperand) {
- return opOperand->get().getType().template isa<RankedTensorType>();
- });
- return result;
+ auto [start, end] = $_op.getOutputsPositionRange();
+ return end - start;
}]
>,
- //===------------------------------------------------------------------===//
- // Output operands handling.
- //===------------------------------------------------------------------===//
InterfaceMethod<
- /*desc=*/[{
- Return the output operands.
- }],
+ /*desc=*/"Return the output operands, i.e. those in getOutputsPositionRange().",
/*retTy=*/"OpOperandVector",
/*methodName=*/"getOutputOperands",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- int64_t numOutputs = getNumOutputs();
+ auto [start, end] = $_op.getOutputsPositionRange();
+
OpOperandVector result;
- result.reserve(numOutputs);
- llvm::transform(
- this->getOperation()->getOpOperands()
- .take_back(numOutputs),
- std::back_inserter(result),
- [](OpOperand &opOperand) { return &opOperand; });
+ result.reserve(end - start);
+ for (int64_t i = start; i < end; ++i)
+ result.push_back(&$_op->getOpOperand(i));
return result;
}]
>,
InterfaceMethod<
- /*desc=*/[{
- Return the `i`-th output operand.
- }],
+ /*desc=*/"Return the `i`-th output operand.",
/*retTy=*/"OpOperand*",
/*methodName=*/"getOutputOperand",
/*args=*/(ins "int64_t":$i),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- assert(i >= 0 && i < getNumOutputs());
- return &this->getOperation()->getOpOperand(getNumInputs() + i);
+ assert(i >= 0 && i < $_op.getNumOutputs());
+ auto [start, end] = $_op.getOutputsPositionRange();
+ return &$_op->getOpOperand(start + i);
}]
>,
InterfaceMethod<
- /*desc=*/[{
- Set the `i`-th output operand.
- }],
+ /*desc=*/"Set the `i`-th output operand.",
/*retTy=*/"void",
/*methodName=*/"setOutputOperand",
/*args=*/(ins "int64_t":$i, "Value":$value),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- assert(i >= 0 && i < getNumOutputs());
- this->getOperation()->setOperand(getNumInputs() + i, value);
+ assert(i >= 0 && i < $_op.getNumOutputs());
+ auto [start, end] = $_op.getOutputsPositionRange();
+ $_op->setOperand(start + i, value);
}]
>,
InterfaceMethod<
- /*desc=*/[{
- Return the subset of output operands that are of buffer type.
- }],
- /*retTy=*/"OpOperandVector",
- /*methodName=*/"getOutputBufferOperands",
+ /*desc=*/"Return the number of inputs.",
+ /*retTy=*/"int64_t",
+ /*methodName=*/"getNumInputs",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- OpOperandVector result;
- result.reserve(getNumOutputs());
- llvm::copy_if(getOutputOperands(),
- std::back_inserter(result),
- [](OpOperand *opOperand) {
- return opOperand->get().getType().template isa<MemRefType>();
- });
- return result;
+ return $_op.getNumOperands() - $_op.getNumOutputs();
}]
>,
InterfaceMethod<
- /*desc=*/[{
- Return the subset of output operands that are of tensor type.
- }],
+ /*desc=*/"Return the input operands (all operands outside the output range).",
/*retTy=*/"OpOperandVector",
- /*methodName=*/"getOutputTensorOperands",
+ /*methodName=*/"getInputOperands",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
+ auto [start, end] = $_op.getOutputsPositionRange();
+ int64_t numOutputs = end - start;
+ int64_t numOperands = $_op.getNumOperands();
+
OpOperandVector result;
- result.reserve(getNumOutputs());
- llvm::copy_if(getOutputOperands(),
- std::back_inserter(result),
- [](OpOperand *opOperand) {
- return opOperand->get().getType().template isa<RankedTensorType>();
- });
- return result;
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
- Return the types of the subset of output operands that are of buffer type.
- }],
- /*retTy=*/"SmallVector<MemRefType>",
- /*methodName=*/"getOutputBufferTypes",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- SmallVector<MemRefType> result;
- result.reserve(getNumOutputs());
- llvm::transform(getOutputBufferOperands(),
- std::back_inserter(result),
- [](OpOperand *opOperands) {
- return opOperands->get().getType().cast<MemRefType>();
- });
+ result.reserve(numOperands - numOutputs);
+ for (int64_t i = 0; i < start; ++i)
+ result.push_back(&$_op->getOpOperand(i));
+ for (int64_t i = end; i < numOperands; ++i)
+ result.push_back(&$_op->getOpOperand(i));
+
return result;
}]
>,
InterfaceMethod<
- /*desc=*/[{
- Return the types of the subset of output operands that are of tensor type.
- }],
- /*retTy=*/"SmallVector<RankedTensorType>",
- /*methodName=*/"getOutputTensorTypes",
- /*args=*/(ins),
+ /*desc=*/"Return the `i`-th input operand.",
+ /*retTy=*/"OpOperand*",
+ /*methodName=*/"getInputOperand",
+ /*args=*/(ins "int64_t":$i),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- SmallVector<RankedTensorType> result;
- result.reserve(getNumOutputs());
- llvm::transform(getOutputTensorOperands(),
- std::back_inserter(result),
- [](OpOperand *opOperands) {
- return opOperands->get().getType().cast<RankedTensorType>();
- });
- return result;
+ assert(i >= 0 && i < $_op.getNumInputs());
+ auto [start, end] = $_op.getOutputsPositionRange();
+ return &$_op->getOpOperand(i < start ? i : i + (end - start));
}]
>,
//===------------------------------------------------------------------===//
// Input and Output arguments handling.
//===------------------------------------------------------------------===//
InterfaceMethod<
- /*desc=*/[{
- Return the range over input and output operands.
- }],
- /*retTy=*/"OpOperandVector",
- /*methodName=*/"getInputAndOutputOperands",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- int64_t numInputsAndOutputs = getNumInputsAndOutputs();
- OpOperandVector result;
- result.reserve(numInputsAndOutputs);
- llvm::transform(
- this->getOperation()->getOpOperands(),
- std::back_inserter(result),
- [](OpOperand &opOperand) { return &opOperand; });
- return result;
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
- Return true if `opOperand` is an input tensor.
- }],
+ /*desc=*/"Return true if `opOperand` is an input.",
/*retTy=*/"bool",
- /*methodName=*/"isInputTensor",
+ /*methodName=*/"isInput",
/*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- if (!opOperand->get().getType().template isa<RankedTensorType>())
- return false;
- if (opOperand->getOperandNumber() < $_op.getNumInputs())
- return true;
- return false;
+ auto [start, end] = $_op.getOutputsPositionRange();
+ auto operandNumber = opOperand->getOperandNumber();
+ return operandNumber < start || operandNumber >= end;
}]
>,
InterfaceMethod<
- /*desc=*/[{
- Return true if `opOperand` is an output tensor.
- }],
+ /*desc=*/"Return true if `opOperand` is an output.",
/*retTy=*/"bool",
- /*methodName=*/"isOutputTensor",
+ /*methodName=*/"isOutput",
/*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- if (!opOperand->get().getType().template isa<RankedTensorType>())
- return false;
- if (opOperand->getOperandNumber() >= $_op.getNumInputs())
- return true;
- return false;
+ auto [start, end] = $_op.getOutputsPositionRange();
+ auto operandNumber = opOperand->getOperandNumber();
+ return operandNumber >= start && operandNumber < end;
}]
>,
InterfaceMethod<
- /*desc=*/[{
- Return true if the `opOperand` is a scalar value.
- }],
+ /*desc=*/"Return true if the `opOperand` is a scalar value.",
/*retTy=*/"bool",
/*methodName=*/"isScalar",
/*args=*/(ins "OpOperand*":$opOperand),
@@ -1263,35 +1068,33 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
}]
>,
InterfaceMethod<
- /*desc=*/[{
- Return the result tied to `opOperand`.
- }],
+ /*desc=*/"Return the result tied to `opOperand`.",
/*retTy=*/"OpResult",
/*methodName=*/"getTiedOpResult",
/*args=*/(ins "OpOperand*":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation());
- int64_t resultIndex = opOperand->getOperandNumber() - getNumInputs();
+
+ auto [start, end] = $_op.getOutputsPositionRange();
+ int64_t resultIndex = opOperand->getOperandNumber() - start;
assert(resultIndex >= 0 &&
- resultIndex < this->getOperation()->getNumResults() );
- return this->getOperation()->getResult(resultIndex);
+ resultIndex < $_op->getNumResults() );
+ return $_op->getResult(resultIndex);
}]
>,
//===------------------------------------------------------------------===//
// Other interface methods.
//===------------------------------------------------------------------===//
InterfaceMethod<
- /*desc=*/[{
- Return whether the op has only MemRef input and outputs.
- }],
+ /*desc=*/"Return whether the op has only MemRef input and outputs.",
/*retTy=*/"bool",
/*methodName=*/"hasBufferSemantics",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return this->getOperation()->getNumResults() == 0 &&
- llvm::all_of(this->getOperation()->getOpOperands(),
+ return $_op->getNumResults() == 0 &&
+ llvm::all_of($_op->getOpOperands(),
[&](OpOperand &opOperand) {
return isScalar(&opOperand) ||
opOperand.get().getType().template isa<MemRefType>();
@@ -1299,15 +1102,13 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
}]
>,
InterfaceMethod<
- /*desc=*/[{
- Return whether the op has only RankedTensor input and outputs.
- }],
+ /*desc=*/"Return whether the op has only RankedTensor input and outputs.",
/*retTy=*/"bool",
/*methodName=*/"hasTensorSemantics",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return llvm::all_of(this->getOperation()->getOpOperands(),
+ return llvm::all_of($_op->getOpOperands(),
[&](OpOperand &opOperand) {
return isScalar(&opOperand) ||
opOperand.get().getType().template isa<RankedTensorType>();
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 2619ad1186408..3b06a59fcf032 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -215,6 +215,10 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
getRegionBuilder() {
return nullptr;
}
+ std::pair<int64_t, int64_t> getOutputsPositionRange() {
+ int64_t numOperands = this->getNumOperands();
+ return {numOperands - static_cast<int64_t>(getOutputs().size()), numOperands};
+ }
}];
let hasCanonicalizer = 1;
@@ -271,11 +275,10 @@ def MapOp : LinalgStructuredBase_Op<"map", [
}
// Implement functions necessary for DestinationStyleOpInterface.
- unsigned getNumInputs() {
- return this->getOperation()->getNumOperands() - getNumOutputs();
- };
- unsigned getNumOutputs() { return 1; };
- mlir::ValueRange getOutputs() { return getOperands().take_back(1); }
+ std::pair<int64_t, int64_t> getOutputsPositionRange() {
+ int64_t numOperands = this->getNumOperands();
+ return {numOperands - 1, numOperands};
+ }
linalg::OpOperandVector getOpOperandsMatchingBBargs() {
return getInputOperands();
}
@@ -341,14 +344,14 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
}
// Implement functions necessary for DestinationStyleOpInterface.
- mlir::ValueRange getOutputs() { return getInits(); }
- unsigned getNumInputs() { return getInputs().size(); };
- unsigned getNumOutputs() { return getInits().size(); };
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
mlir::ArrayRef<mlir::NamedAttribute>)>
getRegionBuilder() {
return nullptr;
}
+ std::pair<int64_t, int64_t> getOutputsPositionRange() {
+ return {getInputs().size(), getNumOperands()};
+ }
}];
let hasCustomAssemblyFormat = 1;
diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index bfb3313d1a21d..2fb5bc651de07 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -29,9 +29,9 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
SmallVector<Type, 8> argTypes;
SmallVector<Location, 8> argLocs;
- for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
- argTypes.push_back(getElementTypeOrSelf(opOperand->get().getType()));
- argLocs.push_back(opOperand->get().getLoc());
+ for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+ argTypes.push_back(getElementTypeOrSelf(opOperand.get().getType()));
+ argLocs.push_back(opOperand.get().getLoc());
}
ImplicitLocOpBuilder b(op->getLoc(), op->getContext());
diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
index 5a6d6788a70a9..c4c7efb0b7c0f 100644
--- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
+++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
@@ -166,6 +166,8 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
<< " and " << *dst.getOperation() << "\n");
if (src.hasTensorSemantics() && dst.hasTensorSemantics()) {
for (OpOperand *dstOpOperand : dst.getInputOperands()) {
+ if (!dstOpOperand->get().getType().isa<RankedTensorType>())
+ continue;
// Check if the operand is defined by the src.
auto definingOp = dstOpOperand->get().getDefiningOp<LinalgOp>();
if (definingOp && definingOp == src)
@@ -188,23 +190,31 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
}
assert(src.hasBufferSemantics() && dst.hasBufferSemantics() &&
"unhandled dependence tracking for mixed buffer/tensor operations");
- for (OpOperand *srcOpOperand : src.getOutputBufferOperands()) { // W
+ for (OpOperand *srcOpOperand : src.getOutputOperands()) { // W
// RAW graph
- for (OpOperand *dstOpOperand : dst.getInputBufferOperands()) // R
+ for (OpOperand *dstOpOperand : dst.getInputOperands()) { // R
+ if (!dstOpOperand->get().getType().isa<MemRefType>())
+ continue;
if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAW alias
addDependenceElem(DependenceType::RAW, srcOpOperand, dstOpOperand);
+ }
// WAW graph
- for (OpOperand *dstOpOperand : dst.getOutputBufferOperands()) // W
+ for (OpOperand *dstOpOperand : dst.getOutputOperands()) // W
if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAW alias
addDependenceElem(DependenceType::WAW, srcOpOperand, dstOpOperand);
}
- for (OpOperand *srcOpOperand : src.getInputBufferOperands()) { // R
+ for (OpOperand *srcOpOperand : src.getInputOperands()) { // R
+ if (!srcOpOperand->get().getType().isa<MemRefType>())
+ continue;
// RAR graph
- for (OpOperand *dstOpOperand : dst.getInputBufferOperands()) // R
+ for (OpOperand *dstOpOperand : dst.getInputOperands()) { // R
+ if (!dstOpOperand->get().getType().isa<MemRefType>())
+ continue;
if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAR alias
addDependenceElem(DependenceType::RAR, srcOpOperand, dstOpOperand);
+ }
// WAR graph
- for (OpOperand *dstOpOperand : dst.getOutputBufferOperands()) // W
+ for (OpOperand *dstOpOperand : dst.getOutputOperands()) // W
if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAR alias
addDependenceElem(DependenceType::WAR, srcOpOperand, dstOpOperand);
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 9a62a40c42be0..88fd71c591ad6 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -31,10 +31,10 @@ using namespace mlir::linalg;
bool linalg::detail::canOpOperandsBeDroppedImpl(
linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
SmallVector<AffineMap> indexingMaps;
- for (auto *opOperand : linalgOp.getInputAndOutputOperands()) {
- if (llvm::is_contained(droppedOperands, opOperand))
+ for (auto &opOperand : linalgOp->getOpOperands()) {
+ if (llvm::is_contained(droppedOperands, &opOperand))
continue;
- indexingMaps.push_back(linalgOp.getMatchingIndexingMap(opOperand));
+ indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
}
return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
}
@@ -491,9 +491,9 @@ static OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value source,
SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
Location loc) {
SmallVector<OpFoldResult> res;
- for (OpOperand *opOperand : getInputAndOutputOperands()) {
- for (int64_t i = 0, e = getRank(opOperand); i < e; ++i)
- res.push_back(createFoldedDimOp(b, loc, opOperand->get(), i));
+ for (OpOperand &opOperand : getOperation()->getOpOperands()) {
+ for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
+ res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i));
}
return res;
}
@@ -501,8 +501,8 @@ SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
SmallVector<int64_t, 4> res;
assert(!hasDynamicShape() && "expected operands to have static shapes");
- for (OpOperand *opOperand : getInputAndOutputOperands())
- llvm::append_range(res, getShape(opOperand));
+ for (OpOperand &opOperand : getOperation()->getOpOperands())
+ llvm::append_range(res, getShape(&opOperand));
return res;
}
@@ -644,32 +644,32 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
// All input/output operands must be indexed.
if (static_cast<int64_t>(linalgOp.getIndexingMapsArray().size()) !=
- linalgOp.getNumInputsAndOutputs())
+ linalgOp->getNumOperands())
return op->emitOpError("expected the number of indexing_map (")
<< linalgOp.getIndexingMapsArray().size()
<< ") to be equal to the number of input/output operands ("
- << linalgOp.getNumInputsAndOutputs() << ")";
+ << linalgOp->getNumOperands() << ")";
- for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
- AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
+ for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+ AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
// Symbols disallowed.
if (indexingMap.getNumSymbols() != 0)
return op->emitOpError("unexpected symbols in indexing_map #")
- << opOperand->getOperandNumber();
+ << opOperand.getOperandNumber();
// Domain must be consistent.
unsigned numLoops = linalgOp.getNumLoops();
if (indexingMap.getNumDims() != numLoops)
return op->emitOpError("expected indexing_map #")
- << opOperand->getOperandNumber() << " to have " << numLoops
+ << opOperand.getOperandNumber() << " to have " << numLoops
<< " dim(s) to match the number of loops";
- int64_t rank = linalgOp.getRank(opOperand);
+ int64_t rank = linalgOp.getRank(&opOperand);
if (indexingMap.getNumResults() != rank)
return op->emitOpError("expected operand rank (")
<< rank << ") to match the result rank of indexing_map #"
- << opOperand->getOperandNumber() << " ("
+ << opOperand.getOperandNumber() << " ("
<< indexingMap.getNumResults() << ")";
}
@@ -688,13 +688,13 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
for (int64_t &range : endLoopRangeValues)
range -= 1;
- for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
- AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
+ for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+ AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
SmallVector<int64_t, 4> startIndices =
indexingMap.compose(startLoopRangeValues);
SmallVector<int64_t, 4> endIndices =
indexingMap.compose(endLoopRangeValues);
- ArrayRef<int64_t> shape = linalgOp.getShape(opOperand);
+ ArrayRef<int64_t> shape = linalgOp.getShape(&opOperand);
for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
// Ignore dynamic dimension or the case that the dimension size is 0
if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
@@ -725,17 +725,16 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
if (indexingMap.getResult(dim).dyn_cast<AffineDimExpr>()) {
if (inferredDimSize != shape[dim]) {
return op->emitOpError("inferred input/output operand #")
- << opOperand->getOperandNumber()
- << " has shape's dimension #" << dim << " to be "
- << inferredDimSize << ", but found " << shape[dim];
+ << opOperand.getOperandNumber() << " has shape's dimension #"
+ << dim << " to be " << inferredDimSize << ", but found "
+ << shape[dim];
}
} else {
if (inferredDimSize > shape[dim]) {
return op->emitOpError("inferred input/output operand #")
- << opOperand->getOperandNumber()
- << " has shape's dimension #" << dim
- << " to be greater than or equal to " << inferredDimSize
- << ", but found " << shape[dim];
+ << opOperand.getOperandNumber() << " has shape's dimension #"
+ << dim << " to be greater than or equal to "
+ << inferredDimSize << ", but found " << shape[dim];
}
}
}
@@ -777,6 +776,15 @@ mlir::linalg::detail::verifyDestinationStyleOpInterface(Operation *op) {
DestinationStyleOpInterface dstStyleOp =
cast<DestinationStyleOpInterface>(op);
+ SmallVector<OpOperand *> outputBufferOperands, outputTensorOperands;
+ for (OpOperand *operand : dstStyleOp.getOutputOperands()) {
+ Type type = operand->get().getType();
+ if (type.isa<MemRefType>())
+ outputBufferOperands.push_back(operand);
+ if (type.isa<RankedTensorType>())
+ outputTensorOperands.push_back(operand);
+ }
+
// Expect at least one output operand.
// This means an op that constructs a tensor out of indices cannot be a
// LinalgOp at the moment. For now this will have to be a special op until we
@@ -788,23 +796,22 @@ mlir::linalg::detail::verifyDestinationStyleOpInterface(Operation *op) {
if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs)))
return failure();
// Verify the number of results matches the number of output tensors.
- if (op->getNumResults() != dstStyleOp.getOutputTensorOperands().size())
+ if (op->getNumResults() != outputTensorOperands.size())
return op->emitOpError("expected the number of results (")
<< op->getNumResults()
<< ") to be equal to the number of output tensors ("
- << dstStyleOp.getOutputTensorOperands().size() << ")";
+ << outputTensorOperands.size() << ")";
// Simplifying assumption: either full tensor or full buffer mode.
// This allows simpler verification of output operands vs result types
// without premature tracking of which operand is what in mixed-mode.
// TODO: relax when mixed-mode needs to pass verification.
- if (!dstStyleOp.getOutputBufferOperands().empty() &&
- !dstStyleOp.getOutputTensorOperands().empty())
+ if (!outputBufferOperands.empty() && !outputTensorOperands.empty())
return op->emitOpError(
"expected output operands to all have tensor type or "
"all have buffer type");
- for (OpOperand *opOperand : dstStyleOp.getOutputTensorOperands()) {
+ for (OpOperand *opOperand : outputTensorOperands) {
OpResult result = dstStyleOp.getTiedOpResult(opOperand);
if (result.getType() != opOperand->get().getType())
return op->emitOpError("expected type of operand #")
@@ -813,6 +820,5 @@ mlir::linalg::detail::verifyDestinationStyleOpInterface(Operation *op) {
<< " to match type of corresponding result (" << result.getType()
<< ")";
}
-
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c2705a383550a..586d1985db449 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -767,7 +767,8 @@ void GenericOp::print(OpAsmPrinter &p) {
}
// Printing is shared with named ops, except for the region and attributes
- printCommonStructuredOpParts(p, getInputs(), getOutputs());
+ printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
+ SmallVector<Value>(getOutputOperands()));
genericAttrNames.push_back("operand_segment_sizes");
genericAttrNamesSet.insert(genericAttrNames.back());
@@ -835,15 +836,20 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
static void getGenericEffectsImpl(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects,
- ValueRange results, ValueRange inputBuffers, ValueRange outputs) {
- for (Value value : inputBuffers) {
- effects.emplace_back(MemoryEffects::Read::get(), value,
+ ValueRange results, OpOperandVector inputOperands,
+ OpOperandVector outputOperands) {
+ for (auto *operand : inputOperands) {
+ if (!operand->get().getType().isa<MemRefType>())
+ continue;
+ effects.emplace_back(MemoryEffects::Read::get(), operand->get(),
SideEffects::DefaultResource::get());
}
- for (Value value : outputs) {
- effects.emplace_back(MemoryEffects::Read::get(), value,
+ for (auto *operand : outputOperands) {
+ if (!operand->get().getType().isa<MemRefType>())
+ continue;
+ effects.emplace_back(MemoryEffects::Read::get(), operand->get(),
SideEffects::DefaultResource::get());
- effects.emplace_back(MemoryEffects::Write::get(), value,
+ effects.emplace_back(MemoryEffects::Write::get(), operand->get(),
SideEffects::DefaultResource::get());
}
}
@@ -851,10 +857,8 @@ static void getGenericEffectsImpl(
void GenericOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- SmallVector<Value> inputBuffers = getInputBufferOperands();
- SmallVector<Value> outputBuffers = getOutputBufferOperands();
- getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
- outputBuffers);
+ getGenericEffectsImpl(effects, getOperation()->getResults(),
+ getInputOperands(), getOutputOperands());
}
static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) {
@@ -925,7 +929,7 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
// Check if there is any change to operands.
if (newInputOperands.size() + newOutputOperands.size() ==
- static_cast<size_t>(genericOp.getNumInputsAndOutputs()))
+ genericOp->getNumOperands())
return failure();
// Create the new op with the body being empty.
@@ -977,35 +981,34 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
SmallVector<AffineMap> &newIndexingMaps) const {
llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
- for (const auto &inputOpOperand :
- llvm::enumerate(genericOp.getInputOperands())) {
+ for (const auto &en : llvm::enumerate(genericOp.getInputOperands())) {
+ OpOperand *inputOpOperand = en.value();
// Check if operand is dead and if dropping the indexing map makes the
// loops to shape computation invalid.
- if (!genericOp.payloadUsesValueFromOperand(inputOpOperand.value())) {
+ if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
// Add the current operands to the list of potentially droppable
// operands. If it cannot be dropped, this needs to be popped back.
- droppedOpOperands.push_back(inputOpOperand.value());
+ droppedOpOperands.push_back(inputOpOperand);
if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
continue;
droppedOpOperands.pop_back();
}
// Check if this operand is a duplicate.
- AffineMap indexingMap =
- genericOp.getMatchingIndexingMap(inputOpOperand.value());
+ AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand);
auto it = dedupedInputs.find(
- std::make_pair(inputOpOperand.value()->get(), indexingMap));
+ std::make_pair(inputOpOperand->get(), indexingMap));
if (it != dedupedInputs.end()) {
- origToNewPos[inputOpOperand.index()] = it->second;
- droppedOpOperands.push_back(inputOpOperand.value());
+ origToNewPos[en.index()] = it->second;
+ droppedOpOperands.push_back(inputOpOperand);
continue;
}
// This is a preserved argument.
- origToNewPos[inputOpOperand.index()] = newInputOperands.size();
- dedupedInputs[{inputOpOperand.value()->get(), indexingMap}] =
+ origToNewPos[en.index()] = newInputOperands.size();
+ dedupedInputs[{inputOpOperand->get(), indexingMap}] =
newInputOperands.size();
- newInputOperands.push_back(inputOpOperand.value()->get());
+ newInputOperands.push_back(inputOpOperand->get());
newIndexingMaps.push_back(indexingMap);
}
return origToNewPos;
@@ -1026,12 +1029,10 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
// If the op doesnt have tensor semantics, keep all the outputs as
// preserved.
if (!genericOp.hasTensorSemantics()) {
- for (const auto &outputOpOperand :
- llvm::enumerate(genericOp.getOutputOperands())) {
- origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
- newOutputOperands.push_back(outputOpOperand.value()->get());
- newIndexingMaps.push_back(
- genericOp.getMatchingIndexingMap(outputOpOperand.value()));
+ for (const auto &en : llvm::enumerate(genericOp.getOutputOperands())) {
+ origToNewPos[en.index()] = newOutputOperands.size();
+ newOutputOperands.push_back(en.value()->get());
+ newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(en.value()));
}
return origToNewPos;
}
@@ -1347,7 +1348,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
}
void MapOp::print(OpAsmPrinter &p) {
- printCommonStructuredOpParts(p, getInputs(), getOutputs());
+ printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
+ SmallVector<Value>(getOutputOperands()));
p.printOptionalAttrDict((*this)->getAttrs());
p << "(";
@@ -1380,7 +1382,7 @@ LogicalResult MapOp::verify() {
// The shape of each input must match the shape of the output.
auto outputShape =
- getOutputs().front().getType().cast<ShapedType>().getShape();
+ getOutputOperand(0)->get().getType().cast<ShapedType>().getShape();
for (Type inputArgType : TypeRange{getInputs()}) {
auto inputElemShape = inputArgType.cast<ShapedType>().getShape();
if (inputElemShape != outputShape) {
@@ -1409,10 +1411,8 @@ ArrayAttr MapOp::getIndexingMaps() {
void MapOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- SmallVector<Value> inputBuffers = getInputBufferOperands();
- SmallVector<Value> outputBuffers = getOutputBufferOperands();
- getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
- outputBuffers);
+ getGenericEffectsImpl(effects, getOperation()->getResults(),
+ getInputOperands(), getOutputOperands());
}
//===----------------------------------------------------------------------===//
@@ -1458,10 +1458,8 @@ ArrayAttr ReduceOp::getIndexingMaps() {
void ReduceOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- SmallVector<Value> inputBuffers = getInputBufferOperands();
- SmallVector<Value> outputBuffers = getOutputBufferOperands();
- getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
- outputBuffers);
+ getGenericEffectsImpl(effects, getOperation()->getResults(),
+ getInputOperands(), getOutputOperands());
}
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
@@ -1500,7 +1498,8 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
}
void ReduceOp::print(OpAsmPrinter &p) {
- printCommonStructuredOpParts(p, getInputs(), getOutputs());
+ printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
+ SmallVector<Value>(getOutputOperands()));
printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
@@ -1584,10 +1583,11 @@ LogicalResult ReduceOp::verify() {
}
// Check that the last block arguments match the element type of the outputs.
- for (auto [output, bbArg] : llvm::zip(
- getOutputs(), block->getArguments().take_back(getNumOutputs()))) {
+ for (auto [output, bbArg] :
+ llvm::zip(getOutputOperands(),
+ block->getArguments().take_back(getNumOutputs()))) {
auto outputElementType =
- output.getType().cast<ShapedType>().getElementType();
+ output->get().getType().cast<ShapedType>().getElementType();
if (outputElementType != bbArg.getType())
return emitOpError()
<< "output element type " << outputElementType
@@ -1751,14 +1751,14 @@ struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
LogicalResult matchAndRewrite(LinalgOp op,
PatternRewriter &rewriter) const override {
- for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
+ for (OpOperand &opOperand : op->getOpOperands()) {
// Linalg "inputs" may be either tensor or memref type.
// tensor<0xelt_type> is a convention that may not always mean
// "0 iterations". Only erase in cases we see memref<...x0x...>.
- auto mt = opOperand->get().getType().dyn_cast<MemRefType>();
+ auto mt = opOperand.get().getType().dyn_cast<MemRefType>();
if (!mt)
continue;
- if (llvm::is_contained(op.getShape(opOperand), 0)) {
+ if (llvm::is_contained(op.getShape(&opOperand), 0)) {
rewriter.eraseOp(op);
return success();
}
@@ -1774,10 +1774,10 @@ struct FoldTensorCastProducerOp : public OpInterfaceRewritePattern<LinalgOp> {
PatternRewriter &rewriter) const override {
// If no operand comes from a tensor::CastOp and can be folded then fail.
bool hasTensorCastOperand =
- llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
- if (opOperand->get().isa<BlockArgument>())
+ llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
+ if (opOperand.get().isa<BlockArgument>())
return false;
- auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
+ auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
return castOp && canFoldIntoConsumerOp(castOp);
});
if (!hasTensorCastOperand)
@@ -1788,18 +1788,17 @@ struct FoldTensorCastProducerOp : public OpInterfaceRewritePattern<LinalgOp> {
SmallVector<Value, 4> newOperands;
newOperands.reserve(op->getNumOperands());
// Inputs may fold.
- for (OpOperand *opOperand : op.getInputOperands()) {
- auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
+ for (auto *input : op.getInputOperands()) {
+ auto tensorCastOp = input->get().getDefiningOp<tensor::CastOp>();
newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
? tensorCastOp.getSource()
- : opOperand->get());
+ : input->get());
}
// Init tensors may fold, in which case the resultType must also change.
- for (OpOperand *opOperand : op.getOutputOperands()) {
- auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
+ for (auto *output : op.getOutputOperands()) {
+ auto tensorCastOp = output->get().getDefiningOp<tensor::CastOp>();
bool fold = canFoldIntoConsumerOp(tensorCastOp);
- newOperands.push_back(fold ? tensorCastOp.getOperand()
- : opOperand->get());
+ newOperands.push_back(fold ? tensorCastOp.getOperand() : output->get());
newResultTypes.push_back(newOperands.back().getType());
}
// Clone op.
@@ -1858,8 +1857,8 @@ struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
Value newOperand =
rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
- SmallVector<Value> newOperands = linalgOp.getInputOperands();
- SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
+ SmallVector<Value> newOperands{linalgOp.getInputOperands()};
+ SmallVector<Value> outputOperands{linalgOp.getOutputOperands()};
outputOperands[resultNumber] = newOperand;
newOperands.append(outputOperands.begin(), outputOperands.end());
@@ -1882,14 +1881,14 @@ struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
/// For each of the operand in `operands` this function maps the static sizes of
/// dimensions to their affine dim expressions.
-static void populateMap(LinalgOp linalgOp, ArrayRef<OpOperand *> operands,
+static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
- for (OpOperand *opOperand : operands) {
- if (linalgOp.isScalar(opOperand))
+ for (OpOperand &opOperand : operands) {
+ if (linalgOp.isScalar(&opOperand))
continue;
- Value src = opOperand->get();
+ Value src = opOperand.get();
auto sourceType = src.getType().cast<RankedTensorType>();
- auto sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
+ auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
// Get the `sourceShape` of the `sourceType`. If the operand is a result of
// `tensor.cast` operation and source of the cast operation has a static
@@ -1932,7 +1931,7 @@ static void createNewOperandWithStaticSizes(
return;
auto sourceType = src.getType().cast<RankedTensorType>();
Type resultType = sourceType;
- if (sourceType.hasStaticShape() && linalgOp.isOutputTensor(opOperand)) {
+ if (sourceType.hasStaticShape() && linalgOp.isOutput(opOperand)) {
resultTypes.push_back(resultType);
return;
}
@@ -1965,7 +1964,7 @@ static void createNewOperandWithStaticSizes(
unsigned index = opOperand->getOperandNumber();
newOperands[index] = newOperand;
}
- if (linalgOp.isOutputTensor(opOperand))
+ if (linalgOp.isOutput(opOperand))
resultTypes.push_back(resultType);
}
@@ -1992,8 +1991,7 @@ struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
// For each of the affine dim expression, check if the size is known. If
// known add that in the map.
- populateMap(linalgOp, linalgOp.getInputAndOutputOperands(),
- affineExprToSize);
+ populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
SmallVector<Value> newOperands;
SmallVector<Type> resultTypes;
@@ -2001,12 +1999,12 @@ struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
// `changeNeeded` is `false` if the operands of `linalgOp` require no
// change in their types.
bool changeNeeded = false;
- newOperands.reserve(linalgOp.getNumInputsAndOutputs());
+ newOperands.reserve(linalgOp->getNumOperands());
resultTypes.reserve(linalgOp.getNumOutputs());
// Iterate over all the operands and update the static sizes.
- for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
- createNewOperandWithStaticSizes(loc, rewriter, opOperand,
+ for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+ createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
affineExprToSize, linalgOp, newOperands,
resultTypes, changeNeeded);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
index 2cf8a57f3fc83..383a9267cdd0d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
@@ -112,14 +112,14 @@ struct BubbleUpExtractSliceOpPattern
tileSizes[position] = sliceOp.getMixedSizes()[result.index()];
}
- SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();
+ SmallVector<Value> valuesToTile = linalgOp->getOperands();
SmallVector<Value> tiledOperands =
makeTiledShapes(rewriter, linalgLoc, linalgOp, valuesToTile,
tileOffsets, tileSizes, sizeBounds,
/*omitPartialTileCheck=*/true);
SmallVector<Type, 4> resultTensorTypes;
- for (OpOperand *opOperand : linalgOp.getOutputTensorOperands())
+ for (OpOperand *opOperand : linalgOp.getOutputOperands())
resultTensorTypes.push_back(
tiledOperands[opOperand->getOperandNumber()].getType());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index abc430faddefe..bb380046a53ba 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -118,7 +118,7 @@ struct LinalgOpInterface
auto genericOp = cast<linalg::DestinationStyleOpInterface>(op);
// The i-th "out" tensor may alias with the i-th OpResult.
- if (genericOp.isOutputTensor(&opOperand))
+ if (genericOp.isOutput(&opOperand))
return {genericOp.getTiedOpResult(&opOperand)};
return {};
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
index 58a54bae2239c..a21e0fc769d68 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
@@ -68,17 +68,17 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
if (!outputType || !outputType.hasStaticShape())
return failure();
- if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) {
- return operand->get().getType().isa<ShapedType>();
+ if (!llvm::all_of(genericOp.getInputs(), [](Value input) {
+ return input.getType().isa<ShapedType>();
}))
return failure();
// Make sure all element types are the same.
- auto getOperandElementType = [](OpOperand *operand) {
- return operand->get().getType().cast<ShapedType>().getElementType();
+ auto getOperandElementType = [](Value value) {
+ return value.getType().cast<ShapedType>().getElementType();
};
- if (!llvm::all_equal(llvm::map_range(genericOp.getInputAndOutputOperands(),
- getOperandElementType)))
+ if (!llvm::all_equal(
+ llvm::map_range(genericOp->getOperands(), getOperandElementType)))
return failure();
// We can only handle the case where we have int/float elements.
@@ -114,15 +114,15 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
// All inputs should be constants.
int numInputs = genericOp.getNumInputs();
SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
- for (const auto &operand : llvm::enumerate(genericOp.getInputOperands())) {
- if (!matchPattern(operand.value()->get(),
- m_Constant(&inputValues[operand.index()])))
+ for (const auto &en : llvm::enumerate(genericOp.getInputOperands())) {
+ if (!matchPattern(en.value()->get(),
+ m_Constant(&inputValues[en.index()])))
return failure();
}
// Identified this as a potential candidate for folding. Now check the
// policy to see whether we are allowed to proceed.
- for (auto *operand : genericOp.getInputOperands()) {
+ for (OpOperand *operand : genericOp.getInputOperands()) {
if (!controlFn(operand))
return failure();
}
@@ -171,8 +171,8 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
APIntOrFloatArray computeFnInputs;
auto inputShapes = llvm::to_vector<4>(
- llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) {
- return operand->get().getType().cast<ShapedType>().getShape();
+ llvm::map_range(genericOp.getInputs(), [](Value value) {
+ return value.getType().cast<ShapedType>().getShape();
}));
// Given a `linearIndex`, remap it to a linear index to access linalg op
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
index cebc978333b22..327e8bec9a7e7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
@@ -194,7 +194,7 @@ DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
}
/// Create the peeled generic op with an empty body.
- SmallVector<Value> outsOperands = genericOp.getOutputOperands();
+ SmallVector<Value> outsOperands = genericOp.getOutputs();
outsOperands.append(newInitValues.begin(), newInitValues.end());
SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes());
resultTypes.append(newResultTypes.begin(), newResultTypes.end());
@@ -212,9 +212,7 @@ DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
PatternRewriter &rewriter) const {
/// Append all results from the peeledGenericOps as `ins` operand for the
/// residual generic op.
- SmallVector<Value> residualGenericOpOperands = llvm::to_vector(
- llvm::map_range(genericOp.getInputOperands(),
- [](OpOperand *operand) { return operand->get(); }));
+ SmallVector<Value> residualGenericOpOperands = genericOp.getInputs();
unsigned origNumResults = genericOp.getNumResults();
unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults();
SmallVector<Value> extraIns;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index baef90c99de4c..acc0126f198ea 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -55,10 +55,9 @@ bool canBeDetensored(TensorType tensorType) {
bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
return genericOp &&
- llvm::all_of(
- genericOp.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
- return !typeConverter.isLegal(opOperand->get().getType());
- });
+ llvm::all_of(genericOp->getOpOperands(), [&](OpOperand &opOperand) {
+ return !typeConverter.isLegal(opOperand.get().getType());
+ });
}
/// A conversion patttern for detensoring `linalg.generic` ops.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 361c85a96dc79..2fcb2acf810ed 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -377,21 +377,21 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
SmallVector<ArrayAttr> reassociationMaps;
SmallVector<Type> newInputOutputTypes;
bool doCanonicalization = false;
- for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
- auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
+ for (OpOperand &opOperand : genericOp->getOpOperands()) {
+ auto replacementInfo = replaceUnitExtents(genericOp, &opOperand, context);
if (replacementInfo) {
reassociationMaps.push_back(replacementInfo->reassociation);
newIndexingMaps.push_back(replacementInfo->indexMap);
newInputOutputTypes.push_back(replacementInfo->type);
doCanonicalization |=
- replacementInfo->type != opOperand->get().getType();
+ replacementInfo->type != opOperand.get().getType();
} else {
// If replaceUnitExtents cannot handle this case, maintain the same
// type, indexing map, and create a set of mappings representing an
// identity matrix.
- newInputOutputTypes.push_back(opOperand->get().getType());
- newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(opOperand));
- int64_t origRank = genericOp.getRank(opOperand);
+ newInputOutputTypes.push_back(opOperand.get().getType());
+ newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(&opOperand));
+ int64_t origRank = genericOp.getRank(&opOperand);
auto maps = llvm::to_vector<8>(llvm::map_range(
llvm::seq<int64_t>(0, origRank), [&](int64_t dim) -> Attribute {
return AffineMapAttr::get(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 05dce4c40272b..80cef16724939 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -90,7 +90,7 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
// Only allow fusing the producer of an input operand for now.
// TODO: allow fusing the producer of an output operand.
- if (!consumer.isInputTensor(fusedOperand))
+ if (!consumer.isInput(fusedOperand))
return false;
// Get the consumer index map. The number of results of the consumer index
@@ -179,7 +179,7 @@ generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp,
}
}
// TODO: allow fusing the producer of an output operand.
- assert(consumer.isInputTensor(fusedOperand) &&
+ assert(consumer.isInput(fusedOperand) &&
"expected producer of input operand");
// 3. Consumer input operands up to consumerIdx (exclusive).
for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
@@ -267,7 +267,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
auto producer = cast<GenericOp>(producerResult.getOwner());
auto consumer = cast<GenericOp>(fusedOperand->getOwner());
// TODO: allow fusing the producer of an output operand.
- assert(consumer.isInputTensor(fusedOperand) &&
+ assert(consumer.isInput(fusedOperand) &&
"expected producer of input operand");
// Compute the fused operands list and indexing maps.
@@ -278,13 +278,14 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
fusedOutputOperands.reserve(producer.getNumOutputs() +
consumer.getNumOutputs());
fusedResultTypes.reserve(producer.getNumOutputs() + consumer.getNumOutputs());
- fusedIndexMaps.reserve(producer.getNumInputsAndOutputs() +
- consumer.getNumInputsAndOutputs());
+ fusedIndexMaps.reserve(producer->getNumOperands() +
+ consumer->getNumOperands());
// In the following, numbering matches that of `generateFusedTensorOpRegion`.
// 3. Consumer input operands/maps up to consumerIdx (exclusive).
- SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
- SmallVector<OpOperand *>::iterator it =
- llvm::find(consumerInputs, fusedOperand);
+ auto consumerInputs = consumer.getInputOperands();
+ auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
+ return operand == fusedOperand;
+ });
assert(it != consumerInputs.end() && "expected to find the consumer operand");
for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
fusedInputOperands.push_back(opOperand->get());
@@ -373,13 +374,13 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
// Find the first operand that is defined by another generic op on tensors.
- for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
- if (!areElementwiseOpsFusable(opOperand))
+ for (OpOperand &opOperand : genericOp->getOpOperands()) {
+ if (!areElementwiseOpsFusable(&opOperand))
continue;
- if (!controlFn(opOperand))
+ if (!controlFn(&opOperand))
continue;
- FailureOr<Operation *> fusedOp = fuseElementwiseOps(rewriter, opOperand);
+ FailureOr<Operation *> fusedOp = fuseElementwiseOps(rewriter, &opOperand);
if (succeeded(fusedOp)) {
auto replacements =
fusedOp.value()->getResults().take_back(genericOp.getNumResults());
@@ -727,9 +728,9 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
: collapsingReshapeOp.getSrc());
continue;
}
- if (genericOp.isInputTensor(opOperand)) {
+ if (auto opOperandType =
+ opOperand->get().getType().dyn_cast<RankedTensorType>()) {
AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
- auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
RankedTensorType expandedOperandType =
getExpandedType(opOperandType, indexingMap, expansionInfo);
if (expandedOperandType != opOperand->get().getType()) {
@@ -833,7 +834,7 @@ class FoldWithProducerReshapeOpByExpansion
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
- for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
+ for (OpOperand *opOperand : genericOp.getInputOperands()) {
tensor::CollapseShapeOp reshapeOp =
opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
if (!reshapeOp)
@@ -1494,17 +1495,17 @@ class FoldWithProducerReshapeOpByCollapsing
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
- for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
+ for (OpOperand &opOperand : genericOp->getOpOperands()) {
tensor::ExpandShapeOp reshapeOp =
- opOperand->get().getDefiningOp<tensor::ExpandShapeOp>();
+ opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
if (!reshapeOp)
continue;
SmallVector<ReassociationIndices> collapsableIterationDims =
- getCollapsableIterationSpaceDims(genericOp, opOperand,
+ getCollapsableIterationSpaceDims(genericOp, &opOperand,
reshapeOp.getReassociationIndices());
if (collapsableIterationDims.empty() ||
- !controlFoldingReshapes(opOperand)) {
+ !controlFoldingReshapes(&opOperand)) {
continue;
}
@@ -1614,7 +1615,7 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
SmallVector<AffineMap> fusedIndexMaps;
SmallVector<Value> fusedOperands;
SmallVector<Location> fusedLocs{genericOp.getLoc()};
- fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
+ fusedIndexMaps.reserve(genericOp->getNumOperands());
fusedOperands.reserve(genericOp.getNumInputs());
fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
for (OpOperand *inputOperand : genericOp.getInputOperands()) {
@@ -1640,7 +1641,7 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
Value scalarConstant = rewriter.create<arith::ConstantOp>(
def->getLoc(), constantAttr, constantAttr.getType());
- SmallVector<Value> outputOperands = genericOp.getOutputOperands();
+ SmallVector<Value> outputOperands = genericOp.getOutputs();
auto fusedOp = rewriter.create<GenericOp>(
rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
/*inputs=*/fusedOperands,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 5738d51373493..5c2c987710936 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -68,7 +68,7 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
bool fromSubViewOpOnly = false) {
// Iterate over the inputs and outputs in order.
// Extract the subranges from the linearized ranges.
- for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
+ for (OpOperand &opOperand : op->getOpOperands()) {
// The method `getRangeFromOperandShape` requires using SubViewOp or
// ExtractSliceOps. If the value isn't defined from there continue.
// todo: The method should be adapted to get the values from
@@ -77,12 +77,12 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
// `std` dialect and add the method to `ViewInterface`.
if (fromSubViewOpOnly &&
!isa_and_nonnull<memref::SubViewOp, tensor::ExtractSliceOp>(
- opOperand->get().getDefiningOp()))
+ opOperand.get().getDefiningOp()))
continue;
- AffineMap map = op.getMatchingIndexingMap(opOperand);
+ AffineMap map = op.getMatchingIndexingMap(&opOperand);
LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange I/O idx: "
- << opOperand->getOperandNumber() << "\n");
+ << opOperand.getOperandNumber() << "\n");
LLVM_DEBUG(llvm::dbgs()
<< "getShapeDefiningLoopRange map: " << map << "\n");
SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
@@ -94,8 +94,8 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
<< loopDepth << "\n");
LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange shape: "
- << opOperand->get() << "\n");
- return ShapeDimension{opOperand->get(),
+ << opOperand.get() << "\n");
+ return ShapeDimension{opOperand.get(),
static_cast<unsigned>(en.index())};
}
}
@@ -104,7 +104,7 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
}
static SmallVector<Value> getTiledOperands(LinalgOp producer) {
- return producer.getInputAndOutputOperands();
+ return producer->getOperands();
}
/// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges`
@@ -137,7 +137,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
}
SmallVector<Value, 8> clonedShapes;
- clonedShapes.reserve(producer.getNumInputsAndOutputs());
+ clonedShapes.reserve(producer->getNumOperands());
// Compute subranges for all tensor input/output operands.
clonedShapes.append(makeTiledShapes(
@@ -150,15 +150,18 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
// fully dynamic at construction time.
SmallVector<Type, 4> resultTypes;
resultTypes.reserve(producer->getNumResults());
- for (RankedTensorType t : producer.getOutputTensorTypes()) {
- unsigned rank = t.getRank();
+ for (OpOperand *operand : producer.getOutputOperands()) {
+ auto tensorType = operand->get().getType().dyn_cast<RankedTensorType>();
+ if (!tensorType)
+ continue;
+ unsigned rank = tensorType.getRank();
SmallVector<int64_t, 4> staticOffsetsVector(
rank, ShapedType::kDynamicStrideOrOffset);
SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
SmallVector<int64_t, 4> staticStridesVector(
rank, ShapedType::kDynamicStrideOrOffset);
resultTypes.push_back(tensor::ExtractSliceOp::inferResultType(
- t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector,
+ tensorType, staticOffsetsVector, staticSizesVector,
staticStridesVector));
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 2451c79a35052..1dd6c35723b9a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -161,7 +161,7 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
allIvs[tiledProducerLoop] = tileIvs[tiledProducerLoop];
}
erase_value(tileIvs, OpFoldResult());
- SmallVector<Value> tiledOperands = producerOp.getInputAndOutputOperands();
+ SmallVector<Value> tiledOperands = producerOp->getOperands();
tiledOperands = makeTiledShapes(b, loc, producerOp, tiledOperands, tileIvs,
tileSizes, producerLoopBounds,
/**omitPartialTileCheck=*/false);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index d656e928c5312..ea6ce398b1a77 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -50,19 +50,19 @@ FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
if (failed(generalizeNamedOpPrecondition(linalgOp)))
return rewriter.notifyMatchFailure(linalgOp, "preconditions not met");
- SmallVector<Value> inputOperands = linalgOp.getInputOperands();
- SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
+ SmallVector<Value> inputs = linalgOp.getInputOperands();
+ SmallVector<Value> outputs = linalgOp.getOutputOperands();
SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
SmallVector<StringRef> iterators = linalgOp.getIteratorTypesArray();
- SmallVector<RankedTensorType> resultTypes = linalgOp.getOutputTensorTypes();
- SmallVector<Type> types(resultTypes.begin(), resultTypes.end());
+ SmallVector<Type> resultTypes = linalgOp.hasTensorSemantics()
+ ? TypeRange(ValueRange(outputs))
+ : TypeRange{};
// All named ops have a region attached that can be inlined.
assert(linalgOp->getNumRegions() == 1 &&
"expect named op to have one region attached");
- GenericOp genericOp =
- rewriter.create<GenericOp>(linalgOp.getLoc(), types, inputOperands,
- outputOperands, indexingMaps, iterators);
+ GenericOp genericOp = rewriter.create<GenericOp>(
+ linalgOp.getLoc(), resultTypes, inputs, outputs, indexingMaps, iterators);
rewriter.inlineRegionBefore(linalgOp->getRegion(0), genericOp.getRegion(),
genericOp.getRegion().begin());
rewriter.replaceOp(linalgOp, genericOp->getResults());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 2ca15dc799413..7515e3006b94d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -111,7 +111,7 @@ struct HoistingAnalysis {
static bool isOnlyUsedAsInputOfLinalgOp(tensor::PadOp padOp) {
for (OpOperand &use : padOp.getResult().getUses()) {
auto linalgUser = dyn_cast<linalg::LinalgOp>(use.getOwner());
- if (!linalgUser || !linalgUser.isInputTensor(&use)) {
+ if (!linalgUser || !linalgUser.isInput(&use)) {
LLVM_DEBUG(DBGS() << "Found a use of " << *(padOp)
<< "\nthat is not an input tensor of a LinalgOp, "
<< "cannot hoist\n"
diff --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
index 04e94b1014e49..4ea889d94e522 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
@@ -43,7 +43,7 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
SmallVector<Value> newOperands;
for (OpOperand *opOperand : genericOp.getInputOperands()) {
AffineMap map = genericOp.getMatchingIndexingMap(opOperand);
- if (genericOp.isInputTensor(opOperand) && map.isConstant()) {
+ if (genericOp.isInput(opOperand) && map.isConstant()) {
scalarOperands.emplace_back(opOperand->getOperandNumber());
} else {
newIndexingMaps.emplace_back(map);
@@ -58,7 +58,7 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
newIndexingMaps.emplace_back(genericOp.getMatchingIndexingMap(opOperand));
Location loc = genericOp->getLoc();
- SmallVector<Value> outputOperands = genericOp.getOutputOperands();
+ SmallVector<Value> outputOperands = genericOp.getOutputs();
auto newOp = rewriter.create<GenericOp>(
loc, genericOp->getResultTypes(), newOperands, outputOperands,
newIndexingMaps, genericOp.getIteratorTypesArray());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
index 8641e1106310e..a74538767d76a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
@@ -67,8 +67,8 @@ mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
// 2. Compute the interchanged indexing maps.
SmallVector<AffineMap> newIndexingMaps;
- for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
- AffineMap m = genericOp.getMatchingIndexingMap(opOperand);
+ for (OpOperand &opOperand : genericOp->getOpOperands()) {
+ AffineMap m = genericOp.getMatchingIndexingMap(&opOperand);
if (!permutationMap.isEmpty())
m = m.compose(permutationMap);
newIndexingMaps.push_back(m);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 3052a4db29464..4fc914905fd7c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -131,7 +131,7 @@ static void emitScalarImplementation(OpBuilder &b, Location loc,
assert(linalgOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
SmallVector<Value> indexedValues;
- indexedValues.reserve(linalgOp.getNumInputsAndOutputs());
+ indexedValues.reserve(linalgOp->getNumOperands());
auto allIvsPlusDims = SmallVector<Value>(allIvs.begin(), allIvs.end());
@@ -161,7 +161,9 @@ static void emitScalarImplementation(OpBuilder &b, Location loc,
// 3. Emit store.
SmallVector<SmallVector<Value>, 8> indexing;
SmallVector<Value> outputBuffers;
- for (OpOperand *outputOperand : linalgOp.getOutputBufferOperands()) {
+ for (OpOperand *outputOperand : linalgOp.getOutputOperands()) {
+ if (!outputOperand->get().getType().isa<MemRefType>())
+ continue;
indexing.push_back(makeCanonicalAffineApplies(
b, loc, linalgOp.getMatchingIndexingMap(outputOperand),
allIvsPlusDims));
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index 17d74fa6dda9c..0995b01092dda 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -145,15 +145,15 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
assert(linalgOp.hasBufferSemantics() && "revisit usage of shaped operand");
auto vUseFullTileBuffers =
options.useFullTileBuffers.value_or(llvm::SmallBitVector());
- vUseFullTileBuffers.resize(linalgOp.getNumInputsAndOutputs(),
+ vUseFullTileBuffers.resize(linalgOp->getNumOperands(),
options.useFullTileBuffersDefault);
- for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
- int64_t operandNumber = opOperand->getOperandNumber();
+ for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+ int64_t operandNumber = opOperand.getOperandNumber();
if (options.operandsToPromote &&
!options.operandsToPromote->count(operandNumber))
continue;
- Operation *op = opOperand->get().getDefiningOp();
+ Operation *op = opOperand.get().getDefiningOp();
if (auto sv = dyn_cast_or_null<memref::SubViewOp>(op)) {
subViews[operandNumber] = sv;
useFullTileBuffers[sv] = vUseFullTileBuffers[operandNumber];
@@ -326,13 +326,13 @@ promoteSubViews(ImplicitLocOpBuilder &b, LinalgOp op,
// operands are not views. This is to support cases such as FillOp taking
// extra scalars etc. Keep a reference to output buffers;
SmallVector<Value, 8> opViews;
- opViews.reserve(op.getNumInputsAndOutputs());
+ opViews.reserve(op->getNumOperands());
SmallVector<std::pair<Value, Value>, 8> writebackViews;
writebackViews.reserve(promotedBuffersAndViews->size());
- for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
- int64_t operandNumber = opOperand->getOperandNumber();
+ for (OpOperand &opOperand : op->getOpOperands()) {
+ int64_t operandNumber = opOperand.getOperandNumber();
if (options.subViews.count(operandNumber) != 0) {
- if (options.useFullTileBuffers[opOperand->get()])
+ if (options.useFullTileBuffers[opOperand.get()])
opViews.push_back(
(*promotedBuffersAndViews)[operandNumber].fullLocalView);
else
@@ -340,10 +340,10 @@ promoteSubViews(ImplicitLocOpBuilder &b, LinalgOp op,
(*promotedBuffersAndViews)[operandNumber].partialLocalView);
if (operandNumber >= op.getNumInputs())
writebackViews.emplace_back(std::make_pair(
- opOperand->get(),
+ opOperand.get(),
(*promotedBuffersAndViews)[operandNumber].partialLocalView));
} else {
- opViews.push_back(opOperand->get());
+ opViews.push_back(opOperand.get());
}
}
op->setOperands(0, opViews.size(), opViews);
@@ -371,12 +371,12 @@ mlir::linalg::promoteSubviewsPrecondition(Operation *op,
if (!linalgOp || !linalgOp.hasBufferSemantics())
return failure();
// Check that at least one of the requested operands is indeed a subview.
- for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
+ for (OpOperand &opOperand : linalgOp->getOpOperands()) {
auto sv =
- isa_and_nonnull<memref::SubViewOp>(opOperand->get().getDefiningOp());
+ isa_and_nonnull<memref::SubViewOp>(opOperand.get().getDefiningOp());
if (sv) {
if (!options.operandsToPromote ||
- options.operandsToPromote->count(opOperand->getOperandNumber()))
+ options.operandsToPromote->count(opOperand.getOperandNumber()))
return success();
}
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index 7df65c823a2fa..92d04c1cca5ba 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -214,7 +214,6 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
// from the previous op.
unsigned intermRank = newOutputShape.size();
AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
- SmallVector<Value> outputOperands = op.getOutputOperands();
SmallVector<StringRef> reductionIteratorTypes;
SmallVector<AffineExpr> exprs;
for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
@@ -230,7 +229,8 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
auto reduction = b.create<GenericOp>(
loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
- outputOperands, reductionMaps, reductionIteratorTypes,
+ SmallVector<Value>{op.getOutputOperands()}, reductionMaps,
+ reductionIteratorTypes,
[reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
Operation *clonedReductionOp = b.clone(*reductionOp);
clonedReductionOp->setOperand(0, inputs[0]);
@@ -341,8 +341,8 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
SmallVector<Operation *> emptyOrAllocTensorOps;
SmallVector<linalg::FillOp> fillOps;
fillOps.reserve(op.getNumOutputs());
- for (auto it : llvm::zip(op.getOutputs(), neutralElements)) {
- Value rankedTensor = std::get<0>(it);
+ for (auto it : llvm::zip(op.getOutputOperands(), neutralElements)) {
+ Value rankedTensor = std::get<0>(it)->get();
auto t = rankedTensor.getType().cast<RankedTensorType>();
RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
reductionDimSize / splitFactor, insertSplitDimension);
@@ -366,7 +366,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
// Step 2. Reindex / expand indexing maps.
// Reindex existing input indexings: k -> k * splitFactor + k'.
SmallVector<AffineMap> newMaps;
- newMaps.reserve(op.getNumInputsAndOutputs() + 1);
+ newMaps.reserve(op->getNumOperands() + 1);
for (OpOperand *o : op.getInputOperands())
newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor));
// Provision a new indexing for the shape-only tensor.
@@ -384,7 +384,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
// Step 3. Handle operands.
// Compute the new input tensors.
- auto newInputs = llvm::to_vector<4>(op.getInputs());
+ SmallVector<Value> newInputs(op.getInputOperands());
// Add a single shape-only tensor to carry the dimensions without resorting to
// more complex inversions.
newInputs.push_back(b.create<tensor::EmptyOp>(
@@ -413,10 +413,10 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
// TODO: all results can be handled in a single GenericOp, when
// multi-reduction support is available.
SmallVector<LinalgOp> results;
- for (auto it :
- llvm::zip(genericOp->getResults(), op.getOutputs(), combinerOps)) {
+ for (auto it : llvm::zip(genericOp->getResults(), op.getOutputOperands(),
+ combinerOps)) {
Value reindexedOutput = std::get<0>(it);
- Value originalOutput = std::get<1>(it);
+ Value originalOutput = std::get<1>(it)->get();
auto originalOutputType = originalOutput.getType().cast<RankedTensorType>();
Operation *combinerOp = std::get<2>(it);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index dd04d00bee523..b66a7189b3501 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -503,7 +503,7 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
// Tile the `operandValuesToUse` that either match the `op` operands
// themselves or the tile loop arguments forwarding them.
assert(operandValuesToUse.size() ==
- static_cast<size_t>(op.getNumInputsAndOutputs()) &&
+ static_cast<size_t>(op->getNumOperands()) &&
"expect the number of operands and inputs and outputs to match");
SmallVector<Value> valuesToTile = operandValuesToUse;
SmallVector<OpFoldResult> sizeBounds =
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 66d55dcf5c713..d88b2c56599a4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -125,14 +125,12 @@ struct LinalgOpTilingInterface
// specified could lead to out of bounds accesses.
Location loc = op->getLoc();
LinalgOp linalgOp = cast<LinalgOp>(op);
- SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();
+ SmallVector<Value> valuesToTile = linalgOp->getOperands();
SmallVector<Value, 4> tiledOperands = makeTiledShapes(
b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
- SmallVector<Type> resultTensorTypes = llvm::to_vector(llvm::map_range(
- linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) {
- return tiledOperands[opOperand->getOperandNumber()].getType();
- }));
+ SmallVector<Type> resultTensorTypes =
+ getTensorOutputTypes(linalgOp, tiledOperands);
Operation *tiledOp =
linalgOp.clone(b, loc, resultTensorTypes, tiledOperands);
@@ -222,23 +220,23 @@ struct LinalgOpTilingInterface
return op->emitOpError("expected operation to have buffer semantics");
SmallVector<Value> indexedValues;
- indexedValues.reserve(linalgOp.getNumInputsAndOutputs());
+ indexedValues.reserve(linalgOp->getNumOperands());
Location linalgOpLoc = op->getLoc();
/// Load the data corresponding to the block arguments that
/// represent input operands.
- for (OpOperand *operand : linalgOp.getInputAndOutputOperands()) {
- if (!linalgOp.payloadUsesValueFromOperand(operand)) {
+ for (OpOperand &operand : linalgOp->getOpOperands()) {
+ if (!linalgOp.payloadUsesValueFromOperand(&operand)) {
indexedValues.push_back(nullptr);
continue;
}
- if (linalgOp.isScalar(operand)) {
- indexedValues.push_back(operand->get());
+ if (linalgOp.isScalar(&operand)) {
+ indexedValues.push_back(operand.get());
continue;
}
SmallVector<Value> indices = getIndicesForAccess(
- builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(operand), ivs);
+ builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(&operand), ivs);
Value load =
- builder.create<memref::LoadOp>(linalgOpLoc, operand->get(), indices);
+ builder.create<memref::LoadOp>(linalgOpLoc, operand.get(), indices);
indexedValues.push_back(load);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 8eb41c5d88b42..eee454b9aec0e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -203,10 +203,10 @@ linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
b.setInsertionPointAfter(opToPad);
// Make a copy of the shaped operands and update it.
SmallVector<Value> newOperands;
- newOperands.reserve(opToPad.getNumInputsAndOutputs());
- for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
+ newOperands.reserve(opToPad->getNumOperands());
+ for (OpOperand &opOperand : opToPad->getOpOperands()) {
FailureOr<Value> paddedOperand = padOperandToSmallestStaticBoundingBox(
- b, opToPad, opOperand, paddingDimensions, paddingValues, packPaddings);
+ b, opToPad, &opOperand, paddingDimensions, paddingValues, packPaddings);
// Exit if `paddingDimensions` cannot be bounded statically.
if (failed(paddedOperand))
return failure();
@@ -327,15 +327,15 @@ mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
// Hoist the padding.
for (const auto &en : enumerate(options.hoistPaddings)) {
- if (static_cast<int64_t>(en.index()) >= paddedOp.getNumInputsAndOutputs())
+ if (static_cast<int64_t>(en.index()) >= paddedOp->getNumOperands())
break;
- OpOperand *opOperand = &paddedOp->getOpOperand(en.index());
- auto padOp = opOperand->get().getDefiningOp<tensor::PadOp>();
+ OpOperand &opOperand = paddedOp->getOpOperand(en.index());
+ auto padOp = opOperand.get().getDefiningOp<tensor::PadOp>();
if (!padOp || en.value() == 0)
continue;
// Fail hoisting if the operand shape is not fully static.
- if (llvm::any_of(paddedOp.getShape(opOperand), ShapedType::isDynamic))
+ if (llvm::any_of(paddedOp.getShape(&opOperand), ShapedType::isDynamic))
return failure();
tensor::PadOp hoistedOp;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 5623a16fb2613..2b70155b24887 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -459,35 +459,35 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
// 3. Turn all BBArgs into vector.transfer_read / load.
Location loc = linalgOp.getLoc();
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
- for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
- BlockArgument bbarg = block->getArgument(opOperand->getOperandNumber());
- if (linalgOp.isScalar(opOperand)) {
- bvm.map(bbarg, opOperand->get());
+ for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+ BlockArgument bbarg = block->getArgument(opOperand.getOperandNumber());
+ if (linalgOp.isScalar(&opOperand)) {
+ bvm.map(bbarg, opOperand.get());
continue;
}
VectorType readType;
AffineMap map;
// TODO: can we keep this simplification?
- // if (linalgOp.getShape(opOperand).empty()) {
+ // if (linalgOp.getShape(&opOperand).empty()) {
// readType = VectorType::get({}, bbarg.getType());
// } else {
- if (opOperand->getOperandNumber() < linalgOp.getNumInputs()) {
+ if (opOperand.getOperandNumber() < linalgOp.getNumInputs()) {
map = inverseAndBroadcastProjectedPermutation(
- linalgOp.getMatchingIndexingMap(opOperand));
+ linalgOp.getMatchingIndexingMap(&opOperand));
readType = VectorType::get(commonVectorShape,
- getElementTypeOrSelf(opOperand->get()));
+ getElementTypeOrSelf(opOperand.get()));
} else {
map = inversePermutation(
- reindexIndexingMap(linalgOp.getMatchingIndexingMap(opOperand)));
- readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
- getElementTypeOrSelf(opOperand->get()));
+ reindexIndexingMap(linalgOp.getMatchingIndexingMap(&opOperand)));
+ readType = VectorType::get(map.compose(linalgOp.getShape(&opOperand)),
+ getElementTypeOrSelf(opOperand.get()));
}
// }
- auto shape = linalgOp.getShape(opOperand);
+ auto shape = linalgOp.getShape(&opOperand);
SmallVector<Value> indices(shape.size(), zero);
Value readValue = b.create<vector::TransferReadOp>(
- loc, readType, opOperand->get(), indices, map);
+ loc, readType, opOperand.get(), indices, map);
// Not all ops support 0-d vectors, extract the scalar for now.
// TODO: remove this.
if (readValue.getType().cast<VectorType>().getRank() == 0)
@@ -495,7 +495,7 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue);
bvm.map(bbarg, readValue);
- bvm.map(opOperand->get(), readValue);
+ bvm.map(opOperand.get(), readValue);
}
SmallVector<CustomVectorizationHook> hooks;
@@ -1342,9 +1342,9 @@ struct Conv1DGenerator : public StructuredGenerator<LinalgOp> {
// Determine whether `linalgOp` can be generated with this generator
if (linalgOp.getNumInputs() != 2 || linalgOp.getNumOutputs() != 1)
return;
- lhsShaped = linalgOp.getInputs()[0];
- rhsShaped = linalgOp.getInputs()[1];
- resShaped = linalgOp.getOutputs()[0];
+ lhsShaped = linalgOp.getInputOperand(0)->get();
+ rhsShaped = linalgOp.getInputOperand(1)->get();
+ resShaped = linalgOp.getOutputOperand(0)->get();
lhsShapedType = lhsShaped.getType().dyn_cast<ShapedType>();
rhsShapedType = rhsShaped.getType().dyn_cast<ShapedType>();
resShapedType = resShaped.getType().dyn_cast<ShapedType>();
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 999034b4e36b0..119c3db644c55 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -490,17 +490,18 @@ void GenerateLoopNest<scf::ForOp>::doit(
assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
"expected as many entries for proc info as number of loops, even if "
"they are null entries");
- SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
+ SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
+ ? SmallVector<Value>{}
+ : linalgOp.getOutputOperands();
SmallVector<Value, 4> lbs, ubs, steps;
unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
LoopNest loopNest = mlir::scf::buildLoopNest(
b, loc, lbs, ubs, steps, iterArgInitValues,
[&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
- assert(iterArgs.size() == linalgOp.getOutputTensorOperands().size() &&
+ assert(iterArgs.size() == iterArgInitValues.size() &&
"expect the number of output tensors and iter args to match");
- SmallVector<Value> operandValuesToUse =
- linalgOp.getInputAndOutputOperands();
+ SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
if (!iterArgs.empty()) {
operandValuesToUse = linalgOp.getInputOperands();
operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
@@ -530,7 +531,9 @@ void GenerateLoopNest<AffineForOp>::doit(
ValueRange)>
bodyBuilderFn,
ArrayRef<linalg::ProcInfo> /*procInfo*/) {
- SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
+ SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
+ ? SmallVector<Value>{}
+ : linalgOp.getOutputOperands();
assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
SmallVector<Value, 4> lbs, ubs, steps;
unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
@@ -546,9 +549,8 @@ void GenerateLoopNest<AffineForOp>::doit(
mlir::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
[&](OpBuilder &b, Location loc, ValueRange ivs) {
- SmallVector<Value> operandValuesToUse =
- linalgOp.getInputAndOutputOperands();
- bodyBuilderFn(b, loc, ivs, operandValuesToUse);
+ bodyBuilderFn(b, loc, ivs,
+ linalgOp->getOperands());
});
}
@@ -695,7 +697,9 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
ValueRange)>
bodyBuilderFn,
ArrayRef<linalg::ProcInfo> procInfo) {
- SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
+ SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
+ ? SmallVector<Value>{}
+ : linalgOp.getOutputOperands();
assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
// This function may be passed more iterator types than ranges.
assert(iteratorTypes.size() >= loopRanges.size() &&
@@ -725,9 +729,7 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
generateParallelLoopNest(
b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
[&](OpBuilder &b, Location loc, ValueRange ivs) {
- SmallVector<Value> operandValuesToUse =
- linalgOp.getInputAndOutputOperands();
- bodyBuilderFn(b, loc, ivs, operandValuesToUse);
+ bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
},
ivs);
@@ -905,10 +907,10 @@ SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
}
SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
- // TODO: use an interface/adaptor to avoid leaking position in
- // `tiledOperands`.
+ if (op.hasBufferSemantics())
+ return {};
return llvm::to_vector(
- llvm::map_range(op.getOutputTensorOperands(), [&](OpOperand *opOperand) {
+ llvm::map_range(op.getOutputOperands(), [&](OpOperand *opOperand) {
return operands[opOperand->getOperandNumber()].getType();
}));
}
@@ -916,11 +918,13 @@ SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
LinalgOp op, ValueRange operands,
ValueRange results) {
+ if (op.hasBufferSemantics())
+ return {};
SmallVector<Value> tensorResults;
tensorResults.reserve(results.size());
// Insert a insert_slice for each output tensor.
unsigned resultIdx = 0;
- for (OpOperand *opOperand : op.getOutputTensorOperands()) {
+ for (OpOperand *opOperand : op.getOutputOperands()) {
// TODO: use an interface/adaptor to avoid leaking position in
// `tiledOperands`.
Value outputTensor = operands[opOperand->getOperandNumber()];
@@ -958,23 +962,26 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
computeTileSizes(builder, loc, tileSizes, sizeBounds);
assert(static_cast<int64_t>(valuesToTile.size()) ==
- linalgOp.getNumInputsAndOutputs() &&
+ linalgOp->getNumOperands() &&
"expected one value to tile for every operand");
SmallVector<Optional<SliceParameters>> allSliceParams;
allSliceParams.reserve(valuesToTile.size());
- for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
- Value shapedOp = valuesToTile[opOperand->getOperandNumber()];
+ for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+ Value shapedOp = valuesToTile[opOperand.getOperandNumber()];
LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
- AffineMap map = linalgOp.getMatchingIndexingMap(opOperand);
+ AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
// Use `opOperand` as is if it is not tiled and not an output tensor. Having
// an extract/insert slice pair for all output tensors simplifies follow up
// transformations such as padding and bufferization since the
// extract/insert slice pairs make the accessed iteration argument
// subdomains explicit.
- if (!isTiled(map, tileSizes) && !linalgOp.isOutputTensor(opOperand)) {
+
+ Type operandType = opOperand.get().getType();
+ if (!isTiled(map, tileSizes) && !(operandType.isa<RankedTensorType>() &&
+ linalgOp.isOutput(&opOperand))) {
allSliceParams.push_back(llvm::None);
- LLVM_DEBUG(llvm::dbgs() << ": not tiled: use shape: "
- << opOperand->get().getType() << "\n");
+ LLVM_DEBUG(llvm::dbgs()
+ << ": not tiled: use shape: " << operandType << "\n");
continue;
}
LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 2458dabef56c9..73f428af29066 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -105,8 +105,7 @@ static bool isZeroYield(GenericOp op) {
auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
if (auto arg = yieldOp.getOperand(0).dyn_cast<BlockArgument>()) {
if (arg.getOwner()->getParentOp() == op) {
- OpOperand *t = op.getInputAndOutputOperands()[arg.getArgNumber()];
- return isZeroValue(t->get());
+ return isZeroValue(op->getOperand(arg.getArgNumber()));
}
}
return isZeroValue(yieldOp.getOperand(0));
@@ -242,8 +241,8 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
return failure();
// Modify operand structure of producer and consumer.
Location loc = prod.getLoc();
- SmallVector<Value> inputOps = prod.getInputOperands();
- SmallVector<Value> outputOps = op.getOutputOperands();
+ SmallVector<Value> inputOps = prod.getInputs();
+ SmallVector<Value> outputOps = op.getOutputs();
SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
inputOps.push_back(op.getInputOperand(1 - other)->get());
fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 1418ed4da4a4a..e5127232f9f62 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -194,14 +194,14 @@ static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a,
/// no annotations are found or inadmissible constructs occur.
static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
bool annotated = false;
- for (OpOperand *t : op.getInputAndOutputOperands()) {
- auto map = op.getMatchingIndexingMap(t);
- auto enc = getSparseTensorEncoding(t->get().getType());
+ for (OpOperand &t : op->getOpOperands()) {
+ auto map = op.getMatchingIndexingMap(&t);
+ auto enc = getSparseTensorEncoding(t.get().getType());
if (enc)
annotated = true;
- assert(map.getNumResults() == op.getRank(t));
+ assert(map.getNumResults() == op.getRank(&t));
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
- unsigned tensor = t->getOperandNumber();
+ unsigned tensor = t.getOperandNumber();
AffineExpr a = map.getResult(toOrigDim(enc, d));
if (!findAffine(merger, tensor, a, toDimLevelFormat(enc, d)))
return false; // inadmissible affine expression
@@ -291,13 +291,13 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
std::vector<unsigned> inDegree(n, 0); // in-degree of each node.
auto iteratorTypes = op.getIteratorTypesArray();
// Iterate over the indexing maps of every tensor in the tensor expression.
- for (OpOperand *t : op.getInputAndOutputOperands()) {
+ for (OpOperand &t : op->getOpOperands()) {
// Skip tensor during cycle resolution.
- if (t == skip)
+ if (&t == skip)
continue;
// Get map and encoding.
- auto map = op.getMatchingIndexingMap(t);
- auto enc = getSparseTensorEncoding(t->get().getType());
+ auto map = op.getMatchingIndexingMap(&t);
+ auto enc = getSparseTensorEncoding(t.get().getType());
assert(map.getNumDims() == n);
// Skip dense tensor constraints when not requested.
if (!(mask & SortMask::kIncludeDense) && !enc)
@@ -314,7 +314,7 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
// Push unrelated loops into sparse iteration space, so these
// will be skipped more often.
if (mask & SortMask::kIncludeUndef) {
- unsigned tensor = t->getOperandNumber();
+ unsigned tensor = t.getOperandNumber();
for (unsigned i = 0; i < n; i++)
if (merger.isDimLevelType(tensor, i, DimLvlType::kCompressed) ||
merger.isDimLevelType(tensor, i, DimLvlType::kSingleton)) {
@@ -534,16 +534,16 @@ static Value genOutputBuffer(CodeGen &codegen, OpBuilder &builder,
static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
linalg::GenericOp op) {
Location loc = op.getLoc();
- assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
+ assert(op->getNumOperands() == op.getNumInputs() + 1);
// For every tensor, find lower and upper bound on dimensions, set the
// same bounds on loop indices, and obtain dense or sparse buffer(s).
auto dynShape = {ShapedType::kDynamicSize};
SmallVector<Value, 4> args;
- for (OpOperand *t : op.getInputAndOutputOperands()) {
- unsigned tensor = t->getOperandNumber();
- auto shape = op.getShape(t);
- auto map = op.getMatchingIndexingMap(t);
- auto enc = getSparseTensorEncoding(t->get().getType());
+ for (OpOperand &t : op->getOpOperands()) {
+ unsigned tensor = t.getOperandNumber();
+ auto shape = op.getShape(&t);
+ auto map = op.getMatchingIndexingMap(&t);
+ auto enc = getSparseTensorEncoding(t.get().getType());
// Scan all dimensions of current tensor.
args.clear();
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
@@ -560,23 +560,23 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
auto dim = builder.getIndexAttr(d);
codegen.pointers[tensor][idx] =
- builder.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
+ builder.create<ToPointersOp>(loc, ptrTp, t.get(), dim);
codegen.indices[tensor][idx] =
- builder.create<ToIndicesOp>(loc, indTp, t->get(), dim);
+ builder.create<ToIndicesOp>(loc, indTp, t.get(), dim);
} else if (merger.isDimLevelType(tensor, idx, DimLvlType::kSingleton)) {
// Singleton dimension, fetch indices.
auto indTp =
MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
auto dim = builder.getIndexAttr(d);
codegen.indices[tensor][idx] =
- builder.create<ToIndicesOp>(loc, indTp, t->get(), dim);
+ builder.create<ToIndicesOp>(loc, indTp, t.get(), dim);
} else {
// Dense dimension, nothing to fetch.
assert(merger.isDimLevelType(tensor, idx, DimLvlType::kDense));
}
// Find upper bound in current dimension.
unsigned p = toOrigDim(enc, d);
- Value up = linalg::createOrFoldDimOp(builder, loc, t->get(), p);
+ Value up = linalg::createOrFoldDimOp(builder, loc, t.get(), p);
if (ShapedType::isDynamic(shape[p]))
args.push_back(up);
assert(codegen.highs[tensor][idx] == nullptr);
@@ -585,21 +585,21 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
// Perform the required bufferization. Dense inputs materialize
// from the input tensors. Dense outputs need special handling.
// Sparse inputs use sparse primitives to obtain the values.
- Type elementType = getElementTypeOrSelf(t->get().getType());
+ Type elementType = getElementTypeOrSelf(t.get().getType());
if (!enc) {
// Non-annotated dense tensors.
auto denseTp = MemRefType::get(shape, elementType);
if (tensor < op.getNumInputs())
codegen.buffers[tensor] =
- builder.create<bufferization::ToMemrefOp>(loc, denseTp, t->get());
+ builder.create<bufferization::ToMemrefOp>(loc, denseTp, t.get());
else
codegen.buffers[tensor] =
genOutputBuffer(codegen, builder, op, denseTp, args);
- } else if (t != codegen.sparseOut) {
+ } else if (&t != codegen.sparseOut) {
// Annotated sparse tensors (not involved in output).
auto sparseTp = MemRefType::get(dynShape, elementType);
codegen.buffers[tensor] =
- builder.create<ToValuesOp>(loc, sparseTp, t->get());
+ builder.create<ToValuesOp>(loc, sparseTp, t.get());
}
}
}
@@ -845,15 +845,15 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen, OpBuilder &builder,
return val;
}
// Load during insertion.
- OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
- if (t == codegen.sparseOut) {
+ OpOperand &t = op->getOpOperand(merger.exp(exp).tensor);
+ if (&t == codegen.sparseOut) {
if (codegen.redCustom != -1u)
- return genInsertionLoadReduce(merger, codegen, builder, op, t);
- return genInsertionLoad(codegen, builder, op, t);
+ return genInsertionLoadReduce(merger, codegen, builder, op, &t);
+ return genInsertionLoad(codegen, builder, op, &t);
}
// Actual load.
SmallVector<Value, 4> args;
- Value ptr = genSubscript(codegen, builder, op, t, args);
+ Value ptr = genSubscript(codegen, builder, op, &t, args);
if (codegen.curVecLength > 1)
return genVectorLoad(codegen, builder, ptr, args);
return builder.create<memref::LoadOp>(op.getLoc(), ptr, args);
@@ -1093,9 +1093,9 @@ static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
if (merger.exp(exp).kind == Kind::kTensor) {
// Inspect tensor indices.
bool atLevel = ldx == -1u;
- OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
- auto map = op.getMatchingIndexingMap(t);
- auto enc = getSparseTensorEncoding(t->get().getType());
+ OpOperand &t = op->getOpOperand(merger.exp(exp).tensor);
+ auto map = op.getMatchingIndexingMap(&t);
+ auto enc = getSparseTensorEncoding(t.get().getType());
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
AffineExpr a = map.getResult(toOrigDim(enc, d));
if (!isInvariantAffine(codegen, a, ldx, atLevel))
@@ -1105,7 +1105,7 @@ static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
if (!atLevel)
return;
OpOperand *lhs = op.getOutputOperand(0);
- if (lhs == t) {
+ if (lhs == &t) {
// Start or end a scalarized reduction
if (atStart) {
Kind kind = merger.exp(last).kind;
@@ -1288,9 +1288,9 @@ static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
/// This prevents effective vectorization.
static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
unsigned idx) {
- for (OpOperand *t : op.getInputAndOutputOperands()) {
- if (!getSparseTensorEncoding(t->get().getType())) {
- auto map = op.getMatchingIndexingMap(t);
+ for (OpOperand &t : op->getOpOperands()) {
+ if (!getSparseTensorEncoding(t.get().getType())) {
+ auto map = op.getMatchingIndexingMap(&t);
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
AffineExpr a = map.getResult(d);
// Report non-unit stride if innermost index appears at an outer
@@ -1856,7 +1856,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
// information for all tensors to loop indices in the kernel.
if (op.getNumOutputs() != 1)
return failure();
- unsigned numTensors = op.getNumInputsAndOutputs();
+ unsigned numTensors = op->getNumOperands();
unsigned numLoops = op.getNumLoops();
Merger merger(numTensors, numLoops);
if (!findSparseAnnotations(merger, op))
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 187a6c0b188b2..b8f6a9385583d 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -910,10 +910,10 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
// argument is considered a tensor, indexed by the implicit loop
// bounds. This includes rank-0 tensor arguments.
if (arg.getOwner()->getParentOp() == op) {
- OpOperand *t = op.getInputAndOutputOperands()[argN];
- if (!op.isScalar(t))
+ OpOperand &t = op->getOpOperand(argN);
+ if (!op.isScalar(&t))
return addExp(kTensor, argN);
- v = t->get(); // get scalar value
+ v = t.get(); // get scalar value
}
// Any other argument (marked as scalar argument for the generic op
// or belonging to an enveloping op) is considered invariant.
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 43589f7f36812..2062c65a45a86 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -275,7 +275,7 @@ func.func @self_copy(%arg0 : memref<2x3x?x4xf32>) {
// -----
// CHECK-LABEL: func @remove_deadargs_generic_basic
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
// CHECK-SAME: outs({{.*}} : tensor<?xf32>) {
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 02471b17a2fbe..f751ddff7df0e 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -121,26 +121,6 @@ func.func @generic(%arg0: memref<?x?xvector<3x4xi4>, strided<[?, 1], offset: ?>>
// CHECK-SAME: outs({{.*}} : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>)
// CHECK-SAME: {foo = 1 : i64}
-func.func @generic_with_tensor_input(%arg0: tensor<?x?xvector<3x4xi4>>,
- %arg1: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
- %cst = arith.constant 0.0 : f32
- linalg.generic #trait_0
- ins(%arg0, %cst : tensor<?x?xvector<3x4xi4>>, f32)
- outs(%arg1 : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>)
- attrs = {foo = 1} {
- ^bb(%0: vector<3x4xi4>, %1: f32, %2: f32) :
- linalg.yield %1 : f32
- }
- return
-}
-// CHECK-LABEL: func @generic_with_tensor_input
-// CHECK: linalg.generic {
-// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
-// CHECK-SAME: library_call = "some_external_function_name_1"}
-// CHECK-SAME: ins({{.*}}, {{.*}} : tensor<?x?xvector<3x4xi4>>, f32)
-// CHECK-SAME: outs({{.*}} : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>)
-// CHECK-SAME: {foo = 1 : i64}
-
// -----
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
@@ -300,27 +280,19 @@ func.func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, strided<[?, 1], offs
func.func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x?xf32>,
%ta3: tensor<?x?x?xf32>, %tb3: tensor<?x?x?xf32>, %tc3: tensor<?x?x?xf32>)
- -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ -> (tensor<?x?x?xf32>)
{
linalg.batch_matmul ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?x?xf32>)
outs(%c3: memref<?x?x?xf32>)
- linalg.batch_matmul ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
- outs(%c3: memref<?x?x?xf32>)
%res1 = linalg.batch_matmul
ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
outs(%tc3: tensor<?x?x?xf32>)
-> tensor<?x?x?xf32>
- %res2 = linalg.batch_matmul
- ins(%ta3, %b3: tensor<?x?x?xf32>, memref<?x?x?xf32>)
- outs(%tc3: tensor<?x?x?xf32>)
- -> tensor<?x?x?xf32>
- return %res1, %res2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
+ return %res1 : tensor<?x?x?xf32>
}
// CHECK-LABEL: func @named_ops
// CHECK: linalg.batch_matmul
// CHECK: linalg.batch_matmul
-// CHECK: linalg.batch_matmul
-// CHECK: linalg.batch_matmul
// -----
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index 0119516f272c0..5807726dd73b4 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -26,7 +26,7 @@ static void addOperands(Operation *op, SetVector<Value> &operandSet) {
return;
TypeSwitch<Operation *, void>(op)
.Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
- SmallVector<Value> inputOperands = linalgOp.getInputOperands();
+ SmallVector<Value> inputOperands{linalgOp.getInputOperands()};
operandSet.insert(inputOperands.begin(), inputOperands.end());
})
.Default([&](Operation *operation) {
@@ -147,7 +147,7 @@ struct TestLinalgElementwiseFusion
if (expandOp->hasOneUse()) {
OpOperand &use = *expandOp->getUses().begin();
auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
- if (linalgOp && linalgOp.isOutputTensor(&use))
+ if (linalgOp && linalgOp.isOutput(&use))
return true;
}
return false;
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
index c5b27c53e8ccb..3c62496f09d24 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -38,14 +38,14 @@ static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) {
// Tile and Fuse for tensors inputs (TODO: all tensor operands).
bool changed = false;
for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
- for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
- if (opOperand->get().getType().isa<MemRefType>()) {
+ for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+ if (opOperand.get().getType().isa<MemRefType>()) {
// TODO: LinalgDependenceGraph should be able to update itself.
// The current naive and expensive reconstruction of the graph should be
// removed.
linalg::Aliases aliases;
linalg::LinalgDependenceGraph graph(aliases, linalgOps);
- auto info = fuseProducerOfBuffer(b, *opOperand, graph);
+ auto info = fuseProducerOfBuffer(b, opOperand, graph);
if (failed(info))
continue;
auto *originalOp = info->originalProducer.getOperation();
@@ -54,11 +54,11 @@ static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) {
std::find(linalgOps.begin(), linalgOps.end(), originalOp);
*originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
changed = true;
- } else if (opOperand->get().getType().isa<RankedTensorType>()) {
+ } else if (opOperand.get().getType().isa<RankedTensorType>()) {
// Tile and Fuse tensor input.
- if (opOperand->getOperandNumber() >= linalgOp.getNumInputs())
+ if (opOperand.getOperandNumber() >= linalgOp.getNumInputs())
continue;
- auto info = fuseProducerOfTensor(b, *opOperand);
+ auto info = fuseProducerOfTensor(b, opOperand);
if (failed(info))
continue;
auto *originalOp = info->originalProducer.getOperation();
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 85c5f3204b722..79ed068eed9c1 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2835,9 +2835,10 @@ def TestLinalgConvOp :
return "";
}
- // To conform with interface requirement on operand naming.
- mlir::ValueRange inputs() { return getInputs(); }
- mlir::ValueRange outputs() { return getOutputs(); }
+ std::pair<int64_t, int64_t> getOutputsPositionRange() {
+ int64_t getNumOperands = this->getNumOperands();
+ return {getNumOperands - 1, getNumOperands};
+ }
}];
}
@@ -2894,9 +2895,10 @@ def TestLinalgFillOp :
return "";
}
- // To conform with interface requirement on operand naming.
- mlir::ValueRange inputs() { return getInputs(); }
- mlir::ValueRange outputs() { return getOutputs(); }
+ std::pair<int64_t, int64_t> getOutputsPositionRange() {
+ int64_t getNumOperands = this->getNumOperands();
+ return {getNumOperands - 1, getNumOperands};
+ }
}];
}
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 8156bb97a32f3..d8e10efaaba04 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -563,6 +563,11 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
return regionBuilder;
}
+ std::pair<int64_t, int64_t> getOutputsPositionRange() {{
+ int64_t getNumOperands = this->getNumOperands();
+ return {{getNumOperands - 1, getNumOperands};
+ }
+
// Generic methods.
static unsigned getNumRegionArgs();
std::string getLibraryCallName();
@@ -638,8 +643,8 @@ ArrayAttr {0}::getIndexingMaps() {{
AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(
getNumParallelLoops(), context);
SmallVector<AffineMap> indexingMaps;
- for (OpOperand *opOperand : getInputAndOutputOperands())
- indexingMaps.push_back(getRank(opOperand) == 0 ? scalarMap : tensorMap);
+ for (OpOperand &opOperand : getOperation()->getOpOperands())
+ indexingMaps.push_back(getRank(&opOperand) == 0 ? scalarMap : tensorMap);
return Builder(getContext()).getAffineMapArrayAttr(indexingMaps);
}
)FMT";
@@ -654,10 +659,9 @@ LogicalResult {0}::fold(ArrayRef<Attribute>,
}
void {0}::getEffects(SmallVectorImpl<
SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
- SmallVector<Value> inputBuffers = getInputBufferOperands();
- SmallVector<Value> outputBuffers = getOutputBufferOperands();
+ if (hasTensorSemantics()) return;
getGenericEffectsImpl(effects,
- getOperation()->getResults(), inputBuffers, outputBuffers);
+ getOperation()->getResults(), getInputOperands(), getOutputOperands());
}
)FMT";
More information about the Mlir-commits
mailing list