[Mlir-commits] [mlir] 7594f50 - [mlir][linalg] Cleanup LinalgOp usage in fusion (NFC).

Tobias Gysi llvmlistbot at llvm.org
Tue Jun 1 01:41:46 PDT 2021


Author: Tobias Gysi
Date: 2021-06-01T08:21:30Z
New Revision: 7594f5028a11c68bcfdf631928ab44889127fab7

URL: https://github.com/llvm/llvm-project/commit/7594f5028a11c68bcfdf631928ab44889127fab7
DIFF: https://github.com/llvm/llvm-project/commit/7594f5028a11c68bcfdf631928ab44889127fab7.diff

LOG: [mlir][linalg] Cleanup LinalgOp usage in fusion (NFC).

Replace the uses of deprecated Structured Op Interface methods in Fusion.cpp. This patch is based on https://reviews.llvm.org/D103394.

Differential Revision: https://reviews.llvm.org/D103437

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 465f933f862c..0263bcb70844 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -69,10 +69,9 @@ struct ShapeDimension {
 static ShapeDimension
 getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
                           bool fromSubViewOpOnly = false) {
-  auto maps = op.indexing_maps();
   // Iterate over the inputs and outputs in order.
   // Extract the subranges from the linearized ranges.
-  for (auto en : llvm::enumerate(op.getShapedOperands())) {
+  for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
     // The method `getRangeFromOperandShape` requires using SubViewOp or
     // SubTensorOps. If the value isnt defined from there continue.
     // todo: The method should be adapted to get the values from
@@ -80,27 +79,26 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
     // currently returns a `linalg.range`. The fix here is to move this op to
     // `std` dialect and add the method to `ViewInterface`.
     if (fromSubViewOpOnly && !isa_and_nonnull<memref::SubViewOp, SubTensorOp>(
-                                 en.value().getDefiningOp()))
+                                 opOperand->get().getDefiningOp()))
       continue;
 
-    unsigned idx = en.index();
-    auto map = maps[idx].cast<AffineMapAttr>().getValue();
-    LLVM_DEBUG(llvm::dbgs()
-               << "getShapeDefiningLoopRange I/O idx: " << idx << "\n");
+    AffineMap map = op.getTiedIndexingMap(opOperand);
+    LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange I/O idx: "
+                            << opOperand->getOperandNumber() << "\n");
     LLVM_DEBUG(llvm::dbgs()
                << "getShapeDefiningLoopRange map: " << map << "\n");
-    Value shape = en.value();
     SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
-    for (auto en2 : llvm::enumerate(map.getResults())) {
-      auto dimExpr = en2.value().dyn_cast<AffineDimExpr>();
+    for (auto en : llvm::enumerate(map.getResults())) {
+      auto dimExpr = en.value().dyn_cast<AffineDimExpr>();
       if (!dimExpr)
         continue;
-      if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
+      if (loopDepth == en.value().cast<AffineDimExpr>().getPosition()) {
         LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
                                 << loopDepth << "\n");
-        LLVM_DEBUG(llvm::dbgs()
-                   << "getShapeDefiningLoopRange shape: " << shape << "\n");
-        return ShapeDimension{shape, static_cast<unsigned>(en2.index())};
+        LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange shape: "
+                                << opOperand->get() << "\n");
+        return ShapeDimension{opOperand->get(),
+                              static_cast<unsigned>(en.index())};
       }
     }
   }
@@ -122,26 +120,24 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
 // would need to add the intermediate results to `linalg.yield`. After that a
 // canonicalization pass would move the unused output args of the `tiled_loop`
 // to the `input` section.
-static SmallVector<Value, 4> getTiledOperands(OpBuilder &b, LinalgOp producer) {
+static SmallVector<Value> getTiledOperands(OpBuilder &b, LinalgOp producer) {
   auto tiledLoop = dyn_cast<TiledLoopOp>(b.getBlock()->getParentOp());
   if (!tiledLoop)
-    return llvm::to_vector<4>(producer.getShapedOperands());
+    return producer.getInputAndOutputOperands();
 
-  SmallVector<Value, 4> tiledOperands;
+  SmallVector<Value> tiledOperands;
   assert(producer.hasTensorSemantics() &&
          "only fusion on tensors is currently supported for TiledLinalgOp");
 
-  for (auto producerInput : producer.getInputTensors()) {
-    OpOperand *addedInput = tiledLoop.findInputOperand(producerInput);
+  for (OpOperand *producerInput : producer.getInputTensorOperands()) {
+    OpOperand *addedInput = tiledLoop.findInputOperand(producerInput->get());
     if (addedInput == nullptr)
-      addedInput = &tiledLoop.appendInputOperand(b, producerInput);
+      addedInput = &tiledLoop.appendInputOperand(b, producerInput->get());
     BlockArgument addedBlockArg = tiledLoop.getTiedBlockArgument(*addedInput);
     tiledOperands.push_back(addedBlockArg);
   }
-  for (auto &en : llvm::enumerate(producer.getOutputTensors())) {
-    Value producerOutput = en.value();
-
-    Value result = producer->getResult(en.index());
+  for (OpOperand *producerOutput : producer.getOutputTensorOperands()) {
+    OpResult result = producer.getTiedOpResult(producerOutput);
     OpOperand *resultInputOperand = tiledLoop.findInputOperand(result);
     OpOperand *resultOutputOperand = tiledLoop.findOutputOperand(result);
     assert((resultInputOperand != nullptr) ^ (resultOutputOperand != nullptr) &&
@@ -152,10 +148,11 @@ static SmallVector<Value, 4> getTiledOperands(OpBuilder &b, LinalgOp producer) {
     int opNumber = isInput ? resultInputOperand->getOperandNumber()
                            : resultOutputOperand->getOperandNumber();
 
-    OpOperand *addedOutput = tiledLoop.findOutputOperand(producerOutput);
+    OpOperand *addedOutput = tiledLoop.findOutputOperand(producerOutput->get());
     if (addedOutput == nullptr)
-      addedOutput = isInput ? &tiledLoop.appendInputOperand(b, producerOutput)
-                            : &tiledLoop.appendOutputOperand(b, producerOutput);
+      addedOutput =
+          isInput ? &tiledLoop.appendInputOperand(b, producerOutput->get())
+                  : &tiledLoop.appendOutputOperand(b, producerOutput->get());
 
     OpOperand &resultOperand = tiledLoop->getOpOperand(opNumber);
     auto addedBlockArg = tiledLoop.getTiedBlockArgument(*addedOutput);
@@ -200,7 +197,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
   }
 
   SmallVector<Value, 8> clonedShapes;
-  clonedShapes.reserve(producer.getNumShapedOperands());
+  clonedShapes.reserve(producer.getNumInputsAndOutputs());
 
   // Compute subranges for all tensor input/output operands.
   clonedShapes.append(makeTiledShapes(b, loc, producer,
@@ -267,16 +264,9 @@ static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
   llvm_unreachable("SubviewOp or SubTensorOp expected");
 }
 
-/// Fuses the producer of `producerIdx` into the loop immediately enclosing
-/// `consumer`. This is achieved by "recomputing" the `producer` at the time it
-/// is needed just before the `consumer.
-///
-/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
-/// 2 cases:
-///   1. Buffer case: `producerIdx` is the index of the buffer in
-///      `producer.getOutputBuffers()`.
-///   2. Tensor case: `producerIdx` is the index of the tensor in
-///      `producer.getResults()`.
+/// Fuses the producer into the loop immediately enclosing the consumer.
+/// This is achieved by "recomputing" the producer at the time it
+/// is needed just before the consumer.
 static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap,
                      OpOperand &consumerOpOperand) {
   LLVM_DEBUG(llvm::dbgs() << "Producer map: " << producerMap << "\n");
@@ -548,9 +538,10 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(consumerOp);
   LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n");
+  OpOperand *opOperand =
+      producerOp.getOutputOperand(producerOpResult.getResultNumber());
   LinalgOp fusedProducer =
-      fuse(b, producerOp,
-           producerOp.getOutputIndexingMap(producerOpResult.getResultNumber()),
+      fuse(b, producerOp, producerOp.getTiedIndexingMap(opOperand),
            consumerOpOperand);
 
   // Replace use.
@@ -770,9 +761,9 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
   FusableOpDependencesTy fusableDependences;
   DenseMap<Operation *, SmallVector<AffineMap, 1>> fusedProducerIndexingMap;
   for (LinalgOp op : reverse(ops)) {
-    for (OpOperand &opOperand : op.getShapedOpOperands()) {
+    for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
       Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
-          fusableDependence = findFusableProducer(opOperand, dependenceGraph);
+          fusableDependence = findFusableProducer(*opOperand, dependenceGraph);
       if (!fusableDependence)
         continue;
       // Canonicalize indexed generic ops before fusion.
@@ -905,10 +896,16 @@ fuseOperations(OpBuilder &b, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp,
     // To keep the second type of information while letting the unfused op die
     // unused, we need to forward the producer output operand.
     if (auto forOp = dyn_cast<scf::ForOp>(tiledLinalgOp.loops.front())) {
-      for (auto &operand : forOp.getIterOpOperands())
-        if (auto opResult = operand.get().dyn_cast<OpResult>())
-          if (opResult.getOwner() == origOp)
-            operand.set(origOp.getOutputTensors()[opResult.getResultNumber()]);
+      for (auto &operand : forOp.getIterOpOperands()) {
+        if (auto opResult = operand.get().dyn_cast<OpResult>()) {
+          if (opResult.getOwner() == origOp) {
+            Value output =
+                origOp.getOutputOperand(opResult.getResultNumber())->get();
+            assert(output.getType().isa<RankedTensorType>());
+            operand.set(output);
+          }
+        }
+      }
     }
   }
   return fusedOps;


        


More information about the Mlir-commits mailing list