[Mlir-commits] [mlir] 8cf650c - [mlir][linalg] Add support for WAW fusion on tensors.

Nicolas Vasilache llvmlistbot at llvm.org
Fri Apr 16 01:27:11 PDT 2021


Author: Nicolas Vasilache
Date: 2021-04-16T08:22:09Z
New Revision: 8cf650c554441935f232d10725403da5038f597e

URL: https://github.com/llvm/llvm-project/commit/8cf650c554441935f232d10725403da5038f597e
DIFF: https://github.com/llvm/llvm-project/commit/8cf650c554441935f232d10725403da5038f597e.diff

LOG: [mlir][linalg] Add support for WAW fusion on tensors.

Differential Revision: https://reviews.llvm.org/D100603

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
    mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
    mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
index fecaeff1c8df4..17fa57d341ca5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
+++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
@@ -229,6 +229,12 @@ class LinalgDependenceGraph {
                          ArrayRef<DependenceType> depTypes = {
                              DependenceType::RAW, DependenceType::WAW}) const;
 
+  /// Print the recorded dependences to `os`; intended as a debugging aid.
+  void print(raw_ostream &os) const;
+
+  /// Print the recorded dependences to stderr; convenient from a debugger.
+  void dump() const;
+
 private:
   // Keep dependences in both directions, this is not just a performance gain
   // but it also reduces usage errors.

diff  --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
index 4b92667b0cc01..3e37979b68ec2 100644
--- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
+++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
@@ -162,6 +162,8 @@ LinalgDependenceGraph::getDependencesInto(
 }
 
 void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
+  LLVM_DEBUG(dbgs() << "addDependencesBetween " << *src.getOperation()
+                    << " and " << *dst.getOperation() << "\n");
   if (src.hasTensorSemantics() && dst.hasTensorSemantics()) {
     for (OpOperand &dstOpOperand : dst.getInputOpOperands()) {
       // Check if the operand is defined by the src.
@@ -170,6 +172,18 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
         addDependenceElem(DependenceType::RAW, dstOpOperand.get(),
                           &dstOpOperand);
     }
+    for (OpOperand &dstOpOperand : dst.getOutputOpOperands()) {
+      // Check if the operand is defined by the src.
+      auto definingOp = dstOpOperand.get().getDefiningOp<LinalgOp>();
+      if (definingOp && definingOp == src) {
+        if (dst.isInitTensor(&dstOpOperand)) {
+          addDependenceElem(DependenceType::RAW, dstOpOperand.get(),
+                            &dstOpOperand);
+        }
+        addDependenceElem(DependenceType::WAW, dstOpOperand.get(),
+                          &dstOpOperand);
+      }
+    }
     return;
   }
   assert(src.hasBufferSemantics() && dst.hasBufferSemantics() &&
@@ -322,3 +336,28 @@ LinalgDependenceGraph::getDependentOperations(
   dependentOperations.append(t.begin(), t.end());
   return dependentOperations;
 }
+
+/// Print all recorded dependences to `os`. For each dependence type, every
+/// source operation is listed, followed by the operations depending on it.
+void LinalgDependenceGraph::print(raw_ostream &os) const {
+  for (auto dt : {
+           LinalgDependenceGraph::DependenceType::RAW,
+           LinalgDependenceGraph::DependenceType::RAR,
+           LinalgDependenceGraph::DependenceType::WAR,
+           LinalgDependenceGraph::DependenceType::WAW,
+       }) {
+    // Streaming `dt` directly prints the enum's integer value; print the
+    // symbolic name (e.g. "RAW") so the output is readable.
+    StringRef dtStr = getDependenceTypeStr(dt);
+    const auto &fromGraph = dependencesFromGraphs[dt];
+    for (const auto &it : fromGraph) {
+      os << "[LinalgDependenceGraph] DT " << dtStr << " from: " << *it.first
+         << ":\n";
+      for (const auto &dep : it.second)
+        os << "\tDT " << dtStr << " " << *dep.getDependentOp() << "\n";
+    }
+  }
+}
+
+/// Print all recorded dependences to stderr; convenient from a debugger.
+void LinalgDependenceGraph::dump() const { print(llvm::errs()); }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 247e4e5c26bb5..e8759ee882675 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -29,6 +29,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 
@@ -331,6 +332,10 @@ bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
 static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
 findFusableProducer(OpOperand &consumerOpOperand,
                     const LinalgDependenceGraph &dependenceGraph) {
+  LLVM_DEBUG(llvm::dbgs() << "findFusableProducer for: "
+                          << consumerOpOperand.get() << " @"
+                          << consumerOpOperand.getOperandNumber() << " in "
+                          << *consumerOpOperand.getOwner() << "\n");
   LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
   if (!consumerOp)
     return {};
@@ -340,9 +345,14 @@ findFusableProducer(OpOperand &consumerOpOperand,
            LinalgDependenceGraph::DependenceType::RAW,
            LinalgDependenceGraph::DependenceType::WAW,
        }) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Dependencies into: " << *consumerOp.getOperation() << "\n");
     for (auto dependence : llvm::make_filter_range(
              dependenceGraph.getDependencesInto(consumerOp, depType),
              [&](LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
+               LLVM_DEBUG(llvm::dbgs() << "Inspect dependence btw: "
+                                       << elem.getIndexingValue() << " and "
+                                       << elem.getDependentValue() << "\n");
                Value v = elem.getIndexingValue();
                Optional<unsigned> operandNum =
                    elem.getIndexingOpViewOperandNum();
@@ -783,12 +793,14 @@ static Optional<TiledLinalgOp> tileRootOperation(
 /// `fusionCandidates`, i.e. move the operation within the inter-tile loops of
 /// `tiledOp`.
 static SmallVector<LinalgOp, 1>
-fuseOperations(OpBuilder &builder, LinalgOp rootOp, LinalgOp tiledOp,
+fuseOperations(OpBuilder &builder, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp,
                ArrayRef<LinalgOp> fusionCandidates,
                const FusableOpDependencesTy &fusableDependences,
                const std::set<unsigned> &fusedLoops) {
+  LinalgOp tiledOp = tiledLinalgOp.op;
   OpBuilder::InsertionGuard guard(builder);
   builder.setInsertionPoint(tiledOp);
+
   DenseMap<unsigned, Range> fusedLoopsAndRanges;
   for (unsigned loop : fusedLoops) {
     ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop, true);
@@ -804,27 +816,49 @@ fuseOperations(OpBuilder &builder, LinalgOp rootOp, LinalgOp tiledOp,
     LinalgOp fusedOp = fuse(builder, origOp, fusedLoopsAndRanges);
     origOpToFusedOp[origOp.getOperation()] = fusedOp;
     fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
+
+    // Prepare the builder for the next insertion point.
+    auto guard =
+        llvm::make_scope_exit([&]() { builder.setInsertionPoint(fusedOp); });
+    if (!origOp.hasTensorSemantics())
+      continue;
+
     // If the producer consumer operations are linalg operations on tensors, the
     // dependence is due to value produced (as a return tensor) by the producer
     // and used in the consumer. The returned value of the fused op needs to be
     // made the operand of the tiled/fused consumer operation. By construction
     // the value returned by the producer is the value used by the consumer.
     for (auto &dependence : fusableDependences.lookup(origOp.getOperation())) {
-      if (origOp.hasTensorSemantics() &&
-          dependence.dependenceType ==
-              LinalgDependenceGraph::DependenceType::RAW) {
-        unsigned resultIndex =
-            dependence.getDependentOpViewResultNum().getValue();
-        LinalgOp consumer = origOpToFusedOp.lookup(dependence.getIndexingOp());
-        if (!consumer)
-          continue;
-        Value replacementValue = fusedOp.getOperation()->getResult(resultIndex);
-        consumer.getOperation()->setOperand(
-            dependence.getIndexingOpViewOperandNum().getValue(),
-            replacementValue);
-      }
+      if (dependence.dependenceType !=
+          LinalgDependenceGraph::DependenceType::RAW)
+        continue;
+
+      unsigned resultIndex =
+          dependence.getDependentOpViewResultNum().getValue();
+      LinalgOp consumer = origOpToFusedOp.lookup(dependence.getIndexingOp());
+      if (!consumer)
+        continue;
+
+      Value replacementValue = fusedOp.getOperation()->getResult(resultIndex);
+      consumer.getOperation()->setOperand(
+          dependence.getIndexingOpViewOperandNum().getValue(),
+          replacementValue);
     }
-    builder.setInsertionPoint(fusedOp);
+
+    // At this point, all Linalg uses of the tensors produced by `origOp` have
+    // been replaced. However, there may still be "output tensor"-like uses
+    // coming from WAW dependencies.
+    // All these uses are iter_args of the outermost loop (TODO: add a check).
+    // Such iter_args uses serve 2 purposes:
+    //  1. give a shape to the output
+    //  2. encode destructive updates that may be inplaceable by bufferization.
+    // To keep the second type of information while letting the unfused op die
+    // unused, we need to forward the producer output operand.
+    for (auto &operand :
+         cast<scf::ForOp>(tiledLinalgOp.loops.front()).getIterOpOperands())
+      if (auto opResult = operand.get().dyn_cast<OpResult>())
+        if (opResult.getOwner() == origOp)
+          operand.set(origOp.getOutputTensors()[opResult.getResultNumber()]);
   }
   return fusedOps;
 }
@@ -860,18 +894,23 @@ tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
   ScopedContext scope(builder, rootOp.getLoc());
 
   // Find all the producers.
+  LLVM_DEBUG(llvm::dbgs() << "findAllFusableDependences\n");
   FusableOpDependencesTy fusableDependences =
       findAllFusableDependences(ops, dependenceGraph);
-  if (fusableDependences.empty())
+  if (fusableDependences.empty()) {
+    LLVM_DEBUG(llvm::dbgs() << "no fusable dependencies found\n");
     return llvm::None;
+  }
 
   TiledAndFusedLinalgOps ret;
   // Find the loops that can be tiled and fused.
+  LLVM_DEBUG(llvm::dbgs() << "collectFusableLoops\n");
   ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences);
 
   // If there are no fusable dependences or there are no tile+fusable loops,
   // just return.
   if (ret.fusedLoopDims.empty()) {
+    LLVM_DEBUG(llvm::dbgs() << "no fusable loops found\n");
     return llvm::None;
   }
 
@@ -888,8 +927,9 @@ tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
   ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
 
   // Fuse the other operations into the fused inter-tile loops produced above.
-  ret.fusedProducers = fuseOperations(builder, rootOp, ret.op, ops.drop_back(),
-                                      fusableDependences, ret.fusedLoopDims);
+  ret.fusedProducers =
+      fuseOperations(builder, rootOp, *tiledRootOp, ops.drop_back(),
+                     fusableDependences, ret.fusedLoopDims);
 
   return ret;
 }

diff  --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
index b24324bba110d..3775b67c43548 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
@@ -143,3 +143,34 @@ module {
 //       CHECK:       scf.yield %[[UPDATE]]
 //       CHECK:     scf.yield %[[YIELD]]
 //       CHECK:   return %[[RESULT]]
+
+// -----
+// WAW fusion: the fill defining the matmul `outs` tensor gets fused per tile.
+module {
+  func @matmul_out_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+                      %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+    %c0 = constant 0.0 : f32
+    %0 = linalg.fill(%arg0, %c0) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
+    %1 = linalg.matmul {__internal_linalg_transform__ = "out_fusion"}
+      ins(%arg1, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+    return %1 : tensor<?x?xf32>
+  }
+}
+
+// CHECK-LABEL: func @matmul_out_fusion(
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//       CHECK: %[[C0:.*]] = constant 0.0{{.*}} : f32
+//   CHECK-NOT: fill
+//       CHECK: scf.for %[[I:.*]]{{.*}}iter_args(%{{.*}} = %[[ARG0]]) -> (tensor<?x?xf32>) {
+//       CHECK:   scf.for %[[J:.*]]
+//       CHECK:     %[[ST:.*]] = subtensor %[[ARG0]]
+//       CHECK:     %[[ST_FILL:.*]] = linalg.fill(%[[ST]], %[[C0]]) {__internal_linalg_transform__ = "after_out_fusion_producer"} : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
+//       CHECK:     %[[ST_MM_RES:.*]] = scf.for %[[K:.*]]{{.*}}iter_args(%[[BB:.*]] = %[[ST_FILL]]) -> (tensor<?x?xf32>) {
+//   CHECK-NOT:       fill
+//       CHECK:       %[[ST_MM:.*]] = linalg.matmul {__internal_linalg_transform__ = "after_out_fusion"} ins(%{{.*}}, %{{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[BB]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+//       CHECK:       scf.yield %[[ST_MM]] : tensor<?x?xf32>
+//       CHECK:     %[[MM:.*]] = subtensor_insert %[[ST_MM_RES]] into {{.*}}
+//       CHECK:     scf.yield %[[MM]] : tensor<?x?xf32>

diff  --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index 3ef6ed5e4b4ba..3e67774ba13a4 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -52,6 +52,19 @@ static void fillFusionPatterns(MLIRContext *context,
           ArrayRef<Identifier>(),
           Identifier::get("after_lhs_fusion_original", context)));
 
+  patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
+      context, dependenceGraph,
+      LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
+      LinalgFusionOptions().setIndicesToFuse({2}),
+      LinalgTransformationFilter(Identifier::get("out_fusion", context),
+                                 Identifier::get("after_out_fusion", context)),
+      LinalgTransformationFilter(
+          ArrayRef<Identifier>(),
+          Identifier::get("after_out_fusion_producer", context)),
+      LinalgTransformationFilter(
+          ArrayRef<Identifier>(),
+          Identifier::get("after_out_fusion_original", context)));
+
   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
       context, dependenceGraph,
       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),


        


More information about the Mlir-commits mailing list