[Mlir-commits] [mlir] d27ab5c - [mlir][Linalg] NFC: Refactor fusion on tensors to enable extending

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 23 13:42:17 PDT 2020


Author: MaheshRavishankar
Date: 2020-04-23T13:41:47-07:00
New Revision: d27ab5c2409b0223ffb6b7ebcb75cd1bde4ac231

URL: https://github.com/llvm/llvm-project/commit/d27ab5c2409b0223ffb6b7ebcb75cd1bde4ac231
DIFF: https://github.com/llvm/llvm-project/commit/d27ab5c2409b0223ffb6b7ebcb75cd1bde4ac231.diff

LOG: [mlir][Linalg] NFC: Refactor fusion on tensors to enable extending
it to fusing different kinds of linalg operations on tensors.

The implementation of fusion on tensors was initially planned for just
GenericOps (and maybe IndexedGenericOps). With the addition of
linalg.tensor_reshape, and potentially other such non-structured ops,
refactor the existing implementation to allow easier specification of
fusion between different linalg operations on tensors.

Differential Revision: https://reviews.llvm.org/D78463

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 7a1e18398cad..7dea577f0a49 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -19,6 +19,7 @@ namespace mlir {
 class AffineExpr;
 class AffineMap;
 class OperationFolder;
+class PatternRewriter;
 
 namespace linalg {
 class LinalgDependenceGraph;
@@ -71,11 +72,11 @@ Optional<FusionInfo> fuseProducerOf(OpBuilder &b, LinalgOp consumer,
                                     const LinalgDependenceGraph &graph,
                                     OperationFolder *folder = nullptr);
 
-/// Fuse linalg operation on tensors, where the result of the producer is used
-/// as the operand of the consumer at position `consumerIdx`.
-Optional<LinalgOp> fuseTensorOps(OpBuilder &b, LinalgOp producer,
-                                 LinalgOp consumer, unsigned consumerIdx,
-                                 OperationFolder *folder = nullptr);
+/// Fuse linalg operation on tensors, with the producer of the operand at
+/// position `consumerIdx` of the consumer.
+Operation *fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
+                         unsigned consumerIdx,
+                         OperationFolder *folder = nullptr);
 
 /// Returns the linearized list of all view dimensions in a linalgOp. Applying
 /// the inverse, concatenated loopToOperandRangeMaps to this list allows the

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index b2100419a114..1184b5f87ea6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -360,154 +360,6 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOf(
   return llvm::None;
 }
 
-/// Checks if two Generic ops are fusible, when one is a producer and another is
-/// a consumer (with the result of the producer being the `consumerIdx` operand
-/// of the consumer).
-static bool areTensorOpsFusible(LinalgOp producer, LinalgOp consumer,
-                                unsigned consumerIdx) {
-  // Verify that the producer and consumer are ops on tensors.
-  if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
-    return false;
-
-  auto producerOp = dyn_cast<linalg::GenericOp>(producer.getOperation());
-  auto consumerOp = dyn_cast<linalg::GenericOp>(consumer.getOperation());
-  // Verify that
-  // - the producer and consumers are generic ops,
-  // - only handle cases where the producer has a single return value,
-  // - the producer return value should be the same as argument at `consumerIdx`
-  //   of the consumer,
-  // - the producer has all "parallel" iterator type.
-  // - only handle ops that use regions for specifying the scalar operations.
-  if (!producerOp || !consumerOp || producerOp.getNumOutputs() != 1 ||
-      producerOp.getResult(0) != consumerOp.getOperand(consumerIdx) ||
-      producerOp.getNumParallelLoops() != producerOp.getNumLoops())
-    return false;
-
-  // Get the consumer index map. The number of results of the consumer index map
-  // must match the number of loops of the producer.
-  AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx);
-  if (consumerIndexMap.getNumResults() != producerOp.getNumLoops())
-    return false;
-
-  // Finally the index_map for the result must be invertible. For now just
-  // verify it is a permutation.
-  AffineMap producerResultIndexMap = producerOp.getOutputIndexingMap(0);
-  return producerResultIndexMap.isPermutation();
-}
-
-/// Computes the indexing maps for arguments of a producer generic op when the
-/// result of the producer is fused with the consumer.
-/// - consumerIndexMap is the indexing_map for the argument in the consumer op
-///   that is the result of the producer op.
-/// - invProducerResultIndexMap is the inverse of the indexing_map for the
-///   result in the producer op.
-/// - producerArgIndexMap is the indexing_map of the argument of the producer
-///   op.
-/// The result is the indexing_map to use for the producer argument when the
-/// producer and consumer ops are fused.
-static AffineMap computeProducerArgMap(AffineMap consumerIndexMap,
-                                       AffineMap invProducerResultIndexMap,
-                                       AffineMap producerArgIndexMap) {
-  // t1 is map from producer result tensor index -> producer arg tensor index.
-  auto t1 = producerArgIndexMap.compose(invProducerResultIndexMap);
-  // The return is map from consumer loop -> producer arg tensor index,
-  // i.e. indexing_map for the producer argument in the fused operation.
-  return t1.compose(consumerIndexMap);
-}
-
-Optional<LinalgOp> mlir::linalg::fuseTensorOps(OpBuilder &b, LinalgOp producer,
-                                               LinalgOp consumer,
-                                               unsigned consumerIdx,
-                                               OperationFolder *folder) {
-  if (!areTensorOpsFusible(producer, consumer, consumerIdx))
-    return {};
-
-  MLIRContext *context = b.getContext();
-  auto producerOp = cast<linalg::GenericOp>(producer.getOperation());
-  auto consumerOp = cast<linalg::GenericOp>(consumer.getOperation());
-  AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx);
-  AffineMap invProducerResultIndexMap =
-      inversePermutation(producerOp.getOutputIndexingMap(0));
-  if (!invProducerResultIndexMap)
-    return {};
-
-  // Compute the fused op operandslist by replacing the operand corresponding to
-  // the result of the producer, with the operands of the producer.
-  unsigned fusedArgsIn =
-      producerOp.getNumInputs() + consumerOp.getNumInputs() - 1;
-  auto fusedArgsOut = consumerOp.getNumOutputs();
-  SmallVector<Value, 2> fusedOperandsList(consumerOp.getOperands());
-  fusedOperandsList.erase(std::next(fusedOperandsList.begin(), consumerIdx));
-  fusedOperandsList.reserve(fusedArgsIn + fusedArgsOut);
-  fusedOperandsList.insert(
-      std::next(fusedOperandsList.begin(), consumerIdx),
-      producerOp.operand_begin(),
-      std::next(producerOp.operand_begin(), producerOp.getNumInputs()));
-
-  // Compute the fused indexing_maps of the operands/results of the fused op.
-  SmallVector<Attribute, 2> fusedIndexingMapAttrs;
-  fusedIndexingMapAttrs.reserve(fusedArgsIn + fusedArgsOut);
-  fusedIndexingMapAttrs.append(consumerOp.indexing_maps().begin(),
-                               consumerOp.indexing_maps().end());
-  fusedIndexingMapAttrs.erase(
-      std::next(fusedIndexingMapAttrs.begin(), consumerIdx));
-  auto *insertPos = std::next(fusedIndexingMapAttrs.begin(), consumerIdx);
-  for (auto producerArgIndexAttr :
-       llvm::enumerate(producerOp.indexing_maps())) {
-    if (producerArgIndexAttr.index() == producerOp.getNumInputs())
-      break;
-    auto composedIndexMap = computeProducerArgMap(
-        consumerIndexMap, invProducerResultIndexMap,
-        producerArgIndexAttr.value().cast<AffineMapAttr>().getValue());
-    insertPos = std::next(fusedIndexingMapAttrs.insert(
-        insertPos, AffineMapAttr::get(composedIndexMap)));
-  }
-
-  // Generate the fused op.
-  auto fusedLinalgOp = b.create<GenericOp>(
-      UnknownLoc::get(context), consumerOp.getResultTypes(), fusedOperandsList,
-      b.getI64IntegerAttr(fusedArgsIn), b.getI64IntegerAttr(fusedArgsOut),
-      b.getArrayAttr(fusedIndexingMapAttrs), consumerOp.iterator_types(),
-      /*doc=*/nullptr,
-      /*library_call=*/nullptr);
-
-  // Build the region of the fused op.
-  auto &fusedOpRegion = fusedLinalgOp.region();
-  Block &producerOpBlock = producerOp.region().front();
-  Block &consumerOpBlock = consumerOp.region().front();
-  Block *fusedBlock = new Block();
-  fusedOpRegion.push_back(fusedBlock);
-  BlockAndValueMapping mapper;
-  // Map the arguments for the unmodified args from the consumer.
-  for (auto consumerOpArg : llvm::enumerate(consumerOpBlock.getArguments())) {
-    if (consumerOpArg.index() == consumerIdx) {
-      // Map the arguments for the args from the producer.
-      for (auto producerOpArg : producerOpBlock.getArguments())
-        mapper.map(producerOpArg,
-                   fusedBlock->addArgument(producerOpArg.getType()));
-      continue;
-    }
-    mapper.map(consumerOpArg.value(),
-               fusedBlock->addArgument(consumerOpArg.value().getType()));
-  }
-
-  // Add operations from producer (except the yield operation) to the fused op.
-  for (auto &op : producerOpBlock.getOperations()) {
-    if (auto yieldOp = dyn_cast<YieldOp>(op)) {
-      // Lookup the value the yield operation is mapped to.
-      Value yieldVal = yieldOp.getOperand(0);
-      auto clonedVal = mapper.lookup(yieldVal);
-      mapper.map(consumerOpBlock.getArgument(consumerIdx), clonedVal);
-      continue;
-    }
-    fusedBlock->push_back(op.clone(mapper));
-  }
-  for (auto &op : consumerOpBlock.getOperations())
-    fusedBlock->push_back(op.clone(mapper));
-
-  return cast<LinalgOp>(fusedLinalgOp.getOperation());
-}
-
 static void fuseLinalgOpsGreedily(FuncOp f) {
   LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));
 
@@ -549,33 +401,206 @@ static void fuseLinalgOpsGreedily(FuncOp f) {
   LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
 }
 
+//====---------------------------------------------------------------------===//
+// Fusion on Tensor operation.
+//====---------------------------------------------------------------------===//
+
 namespace {
 
-/// Patterns to fuse a generic op, with the producer of its operands.
-struct FuseGenericTensorOps : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
+/// Implementation of fusion of generic ops.
+struct FuseGenericOpsOnTensors {
+  static bool isFusible(GenericOp producer, GenericOp consumer,
+                        unsigned consumerIdx) {
+    // Verify that
+    // - the producer has all "parallel" iterator type.
+    if (producer.getNumParallelLoops() != producer.getNumLoops())
+      return false;
+
+    // Get the consumer index map. The number of results of the consumer index
+    // map must match the number of loops of the producer.
+    AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx);
+    if (consumerIndexMap.getNumResults() != producer.getNumLoops())
+      return false;
+
+    // Finally the index_map for the result must be invertible. For now just
+    // verify it is a permutation.
+    AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
+    return producerResultIndexMap.isPermutation();
+  }
 
-  LogicalResult matchAndRewrite(GenericOp op,
-                                PatternRewriter &rewriter) const override {
-    if (!op.hasTensorSemantics())
-      return failure();
+  static Operation *fuse(GenericOp producer, GenericOp consumer,
+                         unsigned consumerIdx, PatternRewriter &rewriter,
+                         OperationFolder *folder = nullptr) {
+    if (!isFusible(producer, consumer, consumerIdx))
+      return nullptr;
+
+    unsigned numFusedOperands = producer.getOperation()->getNumOperands() +
+                                consumer.getOperation()->getNumOperands() - 1;
+
+    // Compute the fused operands list,
+    SmallVector<Value, 2> fusedOperands;
+    fusedOperands.reserve(numFusedOperands);
+    auto consumerOperands = consumer.getOperation()->getOperands();
+    auto producerOperands = producer.getOperation()->getOperands();
+    fusedOperands.assign(consumerOperands.begin(),
+                         std::next(consumerOperands.begin(), consumerIdx));
+    fusedOperands.append(producerOperands.begin(), producerOperands.end());
+    fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1),
+                         consumerOperands.end());
+
+    // Compute indexing_maps for the fused operation. The indexing_maps for the
+    // operands of the consumers that arent fused are the same. The
+    // indexing_maps for the producers need to be computed based on the
+    // indexing_map of the operand at consumerIdx in the consumer.
+    SmallVector<Attribute, 4> fusedIndexMaps;
+    auto consumerIndexMaps = consumer.indexing_maps();
+    fusedIndexMaps.reserve(fusedOperands.size() + consumer.getNumResults());
+    fusedIndexMaps.assign(consumerIndexMaps.begin(),
+                          std::next(consumerIndexMaps.begin(), consumerIdx));
+    // Compute indexing maps for the producer args in the fused operation.
+    computeProducerOperandIndex(
+        producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps);
+
+    // Append the indexing maps for the remaining consumer operands.
+    fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1),
+                          consumerIndexMaps.end());
+
+    // Generate the fused op.
+    auto fusedOp = rewriter.create<GenericOp>(
+        rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands,
+        rewriter.getI64IntegerAttr(fusedOperands.size()),
+        rewriter.getI64IntegerAttr(consumer.getNumResults()),
+        rewriter.getArrayAttr(fusedIndexMaps), consumer.iterator_types(),
+        /*doc=*/nullptr,
+        /*library_call=*/nullptr);
+    generateFusedRegion(rewriter, fusedOp.region(), producer.region(),
+                        consumer.region(), consumerIdx);
+    return fusedOp;
+  }
 
-    // Find the first operand that is defined by another generic op on tensors.
-    for (auto operand : llvm::enumerate(op.getOperation()->getOperands())) {
-      auto definingOp =
-          dyn_cast_or_null<GenericOp>(operand.value().getDefiningOp());
-      if (!definingOp || !definingOp.hasTensorSemantics())
+private:
+  /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
+  /// the `producer` to use in the fused operation given the indexing map of the
+  /// result of the producer in the consumer.
+  static void computeProducerOperandIndex(
+      GenericOp producer, AffineMap fusedConsumerArgIndexMap,
+      SmallVectorImpl<Attribute> &fusedOpIndexingMapAttrs) {
+    // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
+    // from consumer loop -> consumer arg tensor index/producer result tensor
+    // index. The fused loop is same as the consumer loop. For each producer arg
+    // the indexing map to be computed is a map from consumer loop -> producer
+    // arg tensor index.
+
+    AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
+    // producerResultIndexMap is a map from producer loop -> tensor index.
+    // Compute the inverse to get map from tensor index -> producer loop.
+    // The inverse is a map from producer result tensor index -> producer loop.
+    AffineMap invProducerResultIndexMap =
+        inversePermutation(producerResultIndexMap);
+    assert(invProducerResultIndexMap &&
+           "expected producer result indexig map to be invertible");
+    for (unsigned argNum : llvm::seq<unsigned>(0, producer.getNumInputs())) {
+      // argMap is a map from producer loop -> producer arg tensor index.
+      AffineMap argMap = producer.getInputIndexingMap(argNum);
+
+      // Compose argMap with invProducerResultIndexMap to get a map from
+      // producer result tensor index -> producer arg tensor index.
+      AffineMap t1 = argMap.compose(invProducerResultIndexMap);
+
+      // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
+      // consumer loop/ fused loop -> producer arg tensor index.
+      AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap);
+      fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap));
+    }
+  }
+
+  /// Generate the region of the fused operation. The region of the fused op
+  /// must be empty.
+  static void generateFusedRegion(PatternRewriter &rewriter,
+                                  Region &fusedRegion, Region &producerRegion,
+                                  Region &consumerRegion,
+                                  unsigned consumerIdx) {
+    // Build the region of the fused op.
+    Block &producerBlock = producerRegion.front();
+    Block &consumerBlock = consumerRegion.front();
+    Block *fusedBlock = new Block();
+    fusedRegion.push_back(fusedBlock);
+    BlockAndValueMapping mapper;
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(fusedBlock);
+    // Map the arguments for the unmodified args from the consumer.
+    for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
+      if (consumerArg.index() == consumerIdx) {
+        // Map the arguments for the args from the producer.
+        for (auto producerArg : producerBlock.getArguments())
+          mapper.map(producerArg,
+                     fusedBlock->addArgument(producerArg.getType()));
         continue;
-      auto fusedOp =
-          fuseTensorOps(rewriter, cast<LinalgOp>(definingOp.getOperation()),
-                        cast<LinalgOp>(op.getOperation()), operand.index());
-      if (!fusedOp)
+      }
+      mapper.map(consumerArg.value(),
+                 fusedBlock->addArgument(consumerArg.value().getType()));
+    }
+
+    // Add operations from producer (except the yield operation) to the fused
+    // op.
+    for (auto &op : producerBlock.getOperations()) {
+      if (auto yieldOp = dyn_cast<YieldOp>(op)) {
+        // Lookup the value the yield operation is mapped to.
+        Value yieldVal = yieldOp.getOperand(0);
+        auto clonedVal = mapper.lookup(yieldVal);
+        mapper.map(consumerBlock.getArgument(consumerIdx), clonedVal);
         continue;
-      rewriter.replaceOp(op, fusedOp.getValue().getOperation()->getResults());
-      if (llvm::all_of(definingOp.getResults(),
-                       [](Value val) -> bool { return val.use_empty(); }))
-        rewriter.eraseOp(definingOp);
-      return success();
+      }
+      rewriter.clone(op, mapper);
+    }
+    for (auto &op : consumerBlock.getOperations())
+      rewriter.clone(op, mapper);
+  }
+};
+} // namespace
+
+Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
+                                       Operation *consumer,
+                                       unsigned consumerIdx,
+                                       OperationFolder *folder) {
+  if (consumerIdx >= consumer->getNumOperands())
+    return nullptr;
+  Operation *producer = consumer->getOperand(consumerIdx).getDefiningOp();
+  if (!producer || producer->getNumResults() != 1)
+    return nullptr;
+
+  if (GenericOp genericOp = dyn_cast<GenericOp>(consumer)) {
+    if (!genericOp.hasTensorSemantics())
+      return nullptr;
+    if (auto genericOpProducer = dyn_cast<GenericOp>(producer)) {
+      if (genericOpProducer.hasTensorSemantics())
+        return FuseGenericOpsOnTensors::fuse(genericOpProducer, genericOp,
+                                             consumerIdx, rewriter, folder);
+    }
+  }
+  return nullptr;
+}
+
+namespace {
+/// Patterns to fuse a generic op, with the producer of its operands.
+template <typename LinalgOpTy>
+struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
+  using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(LinalgOpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Find the first operand that is defined by another generic op on tensors.
+    for (auto operandNum :
+         llvm::seq<unsigned>(0, op.getOperation()->getNumOperands())) {
+      Operation *producer =
+          op.getOperation()->getOperand(operandNum).getDefiningOp();
+      if (Operation *fusedOp = fuseTensorOps(rewriter, op, operandNum)) {
+        rewriter.replaceOp(op, fusedOp->getResults());
+        if (producer && llvm::all_of(producer->getResults(),
+                                     [](Value val) { return val.use_empty(); }))
+          rewriter.eraseOp(producer);
+        return success();
+      }
     }
     return failure();
   }
@@ -587,7 +612,7 @@ struct FusionOfTensorOpsPass
   void runOnOperation() override {
     OwningRewritePatternList patterns;
     Operation *op = getOperation();
-    patterns.insert<FuseGenericTensorOps>(op->getContext());
+    patterns.insert<FuseTensorOps<GenericOp>>(op->getContext());
     applyPatternsAndFoldGreedily(op->getRegions(), patterns);
   };
 };


        


More information about the Mlir-commits mailing list