[Mlir-commits] [mlir] 7594f50 - [mlir][linalg] Cleanup LinalgOp usage in fusion (NFC).
Tobias Gysi
llvmlistbot at llvm.org
Tue Jun 1 01:41:46 PDT 2021
Author: Tobias Gysi
Date: 2021-06-01T08:21:30Z
New Revision: 7594f5028a11c68bcfdf631928ab44889127fab7
URL: https://github.com/llvm/llvm-project/commit/7594f5028a11c68bcfdf631928ab44889127fab7
DIFF: https://github.com/llvm/llvm-project/commit/7594f5028a11c68bcfdf631928ab44889127fab7.diff
LOG: [mlir][linalg] Cleanup LinalgOp usage in fusion (NFC).
Replace the uses of deprecated Structured Op Interface methods in Fusion.cpp. This patch is based on https://reviews.llvm.org/D103394.
Differential Revision: https://reviews.llvm.org/D103437
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 465f933f862c..0263bcb70844 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -69,10 +69,9 @@ struct ShapeDimension {
static ShapeDimension
getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
bool fromSubViewOpOnly = false) {
- auto maps = op.indexing_maps();
// Iterate over the inputs and outputs in order.
// Extract the subranges from the linearized ranges.
- for (auto en : llvm::enumerate(op.getShapedOperands())) {
+ for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
// The method `getRangeFromOperandShape` requires using SubViewOp or
// SubTensorOps. If the value isnt defined from there continue.
// todo: The method should be adapted to get the values from
@@ -80,27 +79,26 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
// currently returns a `linalg.range`. The fix here is to move this op to
// `std` dialect and add the method to `ViewInterface`.
if (fromSubViewOpOnly && !isa_and_nonnull<memref::SubViewOp, SubTensorOp>(
- en.value().getDefiningOp()))
+ opOperand->get().getDefiningOp()))
continue;
- unsigned idx = en.index();
- auto map = maps[idx].cast<AffineMapAttr>().getValue();
- LLVM_DEBUG(llvm::dbgs()
- << "getShapeDefiningLoopRange I/O idx: " << idx << "\n");
+ AffineMap map = op.getTiedIndexingMap(opOperand);
+ LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange I/O idx: "
+ << opOperand->getOperandNumber() << "\n");
LLVM_DEBUG(llvm::dbgs()
<< "getShapeDefiningLoopRange map: " << map << "\n");
- Value shape = en.value();
SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
- for (auto en2 : llvm::enumerate(map.getResults())) {
- auto dimExpr = en2.value().dyn_cast<AffineDimExpr>();
+ for (auto en : llvm::enumerate(map.getResults())) {
+ auto dimExpr = en.value().dyn_cast<AffineDimExpr>();
if (!dimExpr)
continue;
- if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
+ if (loopDepth == en.value().cast<AffineDimExpr>().getPosition()) {
LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
<< loopDepth << "\n");
- LLVM_DEBUG(llvm::dbgs()
- << "getShapeDefiningLoopRange shape: " << shape << "\n");
- return ShapeDimension{shape, static_cast<unsigned>(en2.index())};
+ LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange shape: "
+ << opOperand->get() << "\n");
+ return ShapeDimension{opOperand->get(),
+ static_cast<unsigned>(en.index())};
}
}
}
@@ -122,26 +120,24 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
// would need to add the intermediate results to `linalg.yield`. After that a
// canonicalization pass would move the unused output args of the `tiled_loop`
// to the `input` section.
-static SmallVector<Value, 4> getTiledOperands(OpBuilder &b, LinalgOp producer) {
+static SmallVector<Value> getTiledOperands(OpBuilder &b, LinalgOp producer) {
auto tiledLoop = dyn_cast<TiledLoopOp>(b.getBlock()->getParentOp());
if (!tiledLoop)
- return llvm::to_vector<4>(producer.getShapedOperands());
+ return producer.getInputAndOutputOperands();
- SmallVector<Value, 4> tiledOperands;
+ SmallVector<Value> tiledOperands;
assert(producer.hasTensorSemantics() &&
"only fusion on tensors is currently supported for TiledLinalgOp");
- for (auto producerInput : producer.getInputTensors()) {
- OpOperand *addedInput = tiledLoop.findInputOperand(producerInput);
+ for (OpOperand *producerInput : producer.getInputTensorOperands()) {
+ OpOperand *addedInput = tiledLoop.findInputOperand(producerInput->get());
if (addedInput == nullptr)
- addedInput = &tiledLoop.appendInputOperand(b, producerInput);
+ addedInput = &tiledLoop.appendInputOperand(b, producerInput->get());
BlockArgument addedBlockArg = tiledLoop.getTiedBlockArgument(*addedInput);
tiledOperands.push_back(addedBlockArg);
}
- for (auto &en : llvm::enumerate(producer.getOutputTensors())) {
- Value producerOutput = en.value();
-
- Value result = producer->getResult(en.index());
+ for (OpOperand *producerOutput : producer.getOutputTensorOperands()) {
+ OpResult result = producer.getTiedOpResult(producerOutput);
OpOperand *resultInputOperand = tiledLoop.findInputOperand(result);
OpOperand *resultOutputOperand = tiledLoop.findOutputOperand(result);
assert((resultInputOperand != nullptr) ^ (resultOutputOperand != nullptr) &&
@@ -152,10 +148,11 @@ static SmallVector<Value, 4> getTiledOperands(OpBuilder &b, LinalgOp producer) {
int opNumber = isInput ? resultInputOperand->getOperandNumber()
: resultOutputOperand->getOperandNumber();
- OpOperand *addedOutput = tiledLoop.findOutputOperand(producerOutput);
+ OpOperand *addedOutput = tiledLoop.findOutputOperand(producerOutput->get());
if (addedOutput == nullptr)
- addedOutput = isInput ? &tiledLoop.appendInputOperand(b, producerOutput)
- : &tiledLoop.appendOutputOperand(b, producerOutput);
+ addedOutput =
+ isInput ? &tiledLoop.appendInputOperand(b, producerOutput->get())
+ : &tiledLoop.appendOutputOperand(b, producerOutput->get());
OpOperand &resultOperand = tiledLoop->getOpOperand(opNumber);
auto addedBlockArg = tiledLoop.getTiedBlockArgument(*addedOutput);
@@ -200,7 +197,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
}
SmallVector<Value, 8> clonedShapes;
- clonedShapes.reserve(producer.getNumShapedOperands());
+ clonedShapes.reserve(producer.getNumInputsAndOutputs());
// Compute subranges for all tensor input/output operands.
clonedShapes.append(makeTiledShapes(b, loc, producer,
@@ -267,16 +264,9 @@ static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
llvm_unreachable("SubviewOp or SubTensorOp expected");
}
-/// Fuses the producer of `producerIdx` into the loop immediately enclosing
-/// `consumer`. This is achieved by "recomputing" the `producer` at the time it
-/// is needed just before the `consumer.
-///
-/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
-/// 2 cases:
-/// 1. Buffer case: `producerIdx` is the index of the buffer in
-/// `producer.getOutputBuffers()`.
-/// 2. Tensor case: `producerIdx` is the index of the tensor in
-/// `producer.getResults()`.
+/// Fuses the producer into the loop immediately enclosing the consumer.
+/// This is achieved by "recomputing" the producer at the time it
+/// is needed just before the consumer.
static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap,
OpOperand &consumerOpOperand) {
LLVM_DEBUG(llvm::dbgs() << "Producer map: " << producerMap << "\n");
@@ -548,9 +538,10 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(consumerOp);
LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n");
+ OpOperand *opOperand =
+ producerOp.getOutputOperand(producerOpResult.getResultNumber());
LinalgOp fusedProducer =
- fuse(b, producerOp,
- producerOp.getOutputIndexingMap(producerOpResult.getResultNumber()),
+ fuse(b, producerOp, producerOp.getTiedIndexingMap(opOperand),
consumerOpOperand);
// Replace use.
@@ -770,9 +761,9 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
FusableOpDependencesTy fusableDependences;
DenseMap<Operation *, SmallVector<AffineMap, 1>> fusedProducerIndexingMap;
for (LinalgOp op : reverse(ops)) {
- for (OpOperand &opOperand : op.getShapedOpOperands()) {
+ for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
- fusableDependence = findFusableProducer(opOperand, dependenceGraph);
+ fusableDependence = findFusableProducer(*opOperand, dependenceGraph);
if (!fusableDependence)
continue;
// Canonicalize indexed generic ops before fusion.
@@ -905,10 +896,16 @@ fuseOperations(OpBuilder &b, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp,
// To keep the second type of information while letting the unfused op die
// unused, we need to forward the producer output operand.
if (auto forOp = dyn_cast<scf::ForOp>(tiledLinalgOp.loops.front())) {
- for (auto &operand : forOp.getIterOpOperands())
- if (auto opResult = operand.get().dyn_cast<OpResult>())
- if (opResult.getOwner() == origOp)
- operand.set(origOp.getOutputTensors()[opResult.getResultNumber()]);
+ for (auto &operand : forOp.getIterOpOperands()) {
+ if (auto opResult = operand.get().dyn_cast<OpResult>()) {
+ if (opResult.getOwner() == origOp) {
+ Value output =
+ origOp.getOutputOperand(opResult.getResultNumber())->get();
+ assert(output.getType().isa<RankedTensorType>());
+ operand.set(output);
+ }
+ }
+ }
}
}
return fusedOps;
More information about the Mlir-commits mailing list