[Mlir-commits] [mlir] 6dd696a - [mlir][Linalg] Extend fusion to support WAW atm on buffers.

Hanhan Wang llvmlistbot at llvm.org
Tue Mar 31 21:34:52 PDT 2020


Author: Hanhan Wang
Date: 2020-03-31T21:33:50-07:00
New Revision: 6dd696ae4fa1b1564e76e5531c268724d2c8b98f

URL: https://github.com/llvm/llvm-project/commit/6dd696ae4fa1b1564e76e5531c268724d2c8b98f
DIFF: https://github.com/llvm/llvm-project/commit/6dd696ae4fa1b1564e76e5531c268724d2c8b98f.diff

LOG: [mlir][Linalg] Extend fusion to support WAW atm on buffers.

Summary:
The RAW fusion happens only if the producer block dominates the consumer block.
The WAW pattern works under the same precondition, i.e., if the producer
dominates the consumer, the two ops can be fused together.

Since both ops are tilable, we can think of the pattern this way:

Input:
```
linalg_op1 view

tile_loop
  subview_2
  linalg_op2 subview_2
```

Tile the first Linalg op the same way as the second Linalg op:
```
tile_loop
  subview_1
  linalg_op1 subview_1

tile_loop
  subview_2
  linalg_op2 subview_2
```

Since the first Linalg op is tilable in the same way and the computations are
independent, it is valid to fuse it with the second Linalg op:
```
tile_loop
  subview_1
  linalg_op1 subview_1
  linalg_op2 subview_2
```

In short, this patch includes:
- Handling both RAW and WAW patterns.
- Adding an interface method to get input and output buffers.
- Exposing a method to get a StringRef of a dependency type.
- Fixing existing WAW tests and adding one more use case: initializing the
  buffer before a conv op.

Differential Revision: https://reviews.llvm.org/D76897

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
    mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/test/Dialect/Linalg/fusion.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
index 34de176a998e..e40d63661b77 100644
--- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
+++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
@@ -63,6 +63,7 @@ class LinalgDependenceGraph {
   using dependence_range = iterator_range<dependence_iterator>;
 
   enum DependenceType { RAR = 0, RAW, WAR, WAW, NumTypes };
+  static StringRef getDependenceTypeStr(DependenceType depType);
 
   // Builds a linalg dependence graph for the ops of type LinalgOp under `f`.
   static LinalgDependenceGraph buildDependenceGraph(Aliases &aliases, FuncOp f);

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index 8fcc1ceea502..46fb9881aba5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -100,6 +100,10 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     //===------------------------------------------------------------------===//
     // Input and Output arguments handling.
     //===------------------------------------------------------------------===//
+    InterfaceMethod<
+      "Return one single buffer at position `$i`.",
+      "Value", "getBuffer", (ins "unsigned":$i)
+    >,
     InterfaceMethod<
       "Return the number of inputs and outputs, irrespective of their buffer "
       "or tensor type.",

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
index f546d3670b6a..b13b6d268226 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -184,6 +184,10 @@ class StructuredOpTraits
   //==========================================================================//
   // Input and Output arguments handling.
   //==========================================================================//
+  Value getBuffer(unsigned i) { // i-th input/output buffer, i.e. operand `i`.
+    assert(getNumInputsAndOutputBuffers() > i && "overflowing buffers index");
+    return this->getOperation()->getOperand(i);
+  }
   /// Return the number of inputs and outputs, irrespective of their buffer or
   /// tensor type.
   unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); }

diff  --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
index bf52039a6dc1..90ce8fd6bb0b 100644
--- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
+++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
@@ -24,24 +24,6 @@ using namespace mlir::linalg;
 
 using llvm::dbgs;
 
-#ifndef NDEBUG
-static StringRef toStringRef(LinalgDependenceGraph::DependenceType dt) {
-  switch (dt) {
-  case LinalgDependenceGraph::DependenceType::RAW:
-    return "RAW";
-  case LinalgDependenceGraph::DependenceType::RAR:
-    return "RAR";
-  case LinalgDependenceGraph::DependenceType::WAR:
-    return "WAR";
-  case LinalgDependenceGraph::DependenceType::WAW:
-    return "WAW";
-  default:
-    break;
-  }
-  llvm_unreachable("Unexpected DependenceType");
-}
-#endif
-
 Value Aliases::find(Value v) {
   if (v.isa<BlockArgument>())
     return v;
@@ -76,6 +58,22 @@ Value Aliases::find(Value v) {
   }
 }
 
+/// Returns the short mnemonic ("RAR", "RAW", "WAR", "WAW") for `depType`.
+StringRef LinalgDependenceGraph::getDependenceTypeStr(DependenceType depType) {
+  switch (depType) {
+  case RAR:
+    return "RAR";
+  case RAW:
+    return "RAW";
+  case WAR:
+    return "WAR";
+  case WAW:
+    return "WAW";
+  default:
+    llvm_unreachable("Unexpected DependenceType");
+  }
+}
+
 LinalgDependenceGraph
 LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, FuncOp f) {
   SmallVector<Operation *, 8> linalgOps;
@@ -100,7 +98,7 @@ LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
 void LinalgDependenceGraph::addDependenceElem(DependenceType dt,
                                               LinalgOpView indexingOpView,
                                               LinalgOpView dependentOpView) {
-  LLVM_DEBUG(dbgs() << "\nAdd dep type " << toStringRef(dt) << ":\t"
+  LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t"
                     << *indexingOpView.op << " -> " << *dependentOpView.op);
   dependencesFromGraphs[dt][indexingOpView.op].push_back(
       LinalgDependenceGraphElem{dependentOpView, indexingOpView.view});
@@ -227,8 +225,8 @@ LinalgDependenceGraph::findOperationsWithCoveringDependences(
         continue;
       auto *op = dependence.dependentOpView.op;
       LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type "
-                        << toStringRef(dt) << ": " << *src << " -> " << *op
-                        << " on " << dependence.indexingView);
+                        << getDependenceTypeStr(dt) << ": " << *src << " -> "
+                        << *op << " on " << dependence.indexingView);
       res.push_back(op);
     }
   }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index b6af16c979c3..4d20bb541e28 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -157,9 +157,9 @@ static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
   }
 
   auto subView = dyn_cast_or_null<SubViewOp>(
-      consumer.getInput(consumerIdx).getDefiningOp());
-  auto slice =
-      dyn_cast_or_null<SliceOp>(consumer.getInput(consumerIdx).getDefiningOp());
+      consumer.getBuffer(consumerIdx).getDefiningOp());
+  auto slice = dyn_cast_or_null<SliceOp>(
+      consumer.getBuffer(consumerIdx).getDefiningOp());
   assert(subView || slice);
   (void)subView;
   (void)slice;
@@ -274,16 +274,15 @@ bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
   return true;
 }
 
-// Only consider RAW atm.
-Optional<FusionInfo> mlir::linalg::fuseProducerOf(
-    OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
-    const LinalgDependenceGraph &graph, OperationFolder *folder) {
+static Optional<FusionInfo>
+fuseProducerOfDep(OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
+                  const LinalgDependenceGraph &graph, OperationFolder *folder,
+                  LinalgDependenceGraph::DependenceType depType) {
   assert(consumer.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   LLVM_DEBUG(dbgs() << "\nStart examining consumer: "
                     << *consumer.getOperation());
-  for (auto dependence : graph.getDependencesInto(
-           consumer, LinalgDependenceGraph::DependenceType::RAW)) {
+  for (auto dependence : graph.getDependencesInto(consumer, depType)) {
     LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
                       << *dependence.dependentOpView.op << "\n");
     auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
@@ -294,7 +293,7 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOf(
 
     // Check that the dependence is indeed on the input `consumerIdx` view.
     auto consumedView = dependence.indexingView;
-    if (consumer.getInput(consumerIdx) != consumedView)
+    if (consumer.getBuffer(consumerIdx) != consumedView)
       continue;
 
     // Consumer consumes this view, `isStructurallyFusableProducer` also checks
@@ -302,9 +301,10 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOf(
     auto producedView = dependence.dependentOpView.view;
     auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue();
     // `consumerIdx` and `producerIdx` exist by construction.
-    LLVM_DEBUG(dbgs() << "\nRAW producer: " << *producer.getOperation()
-                      << " view: " << producedView
-                      << " output index: " << producerIdx);
+    LLVM_DEBUG(dbgs() << "\n"
+                      << LinalgDependenceGraph::getDependenceTypeStr(depType)
+                      << "producer: " << *producer.getOperation() << " view: "
+                      << producedView << " output index: " << producerIdx);
 
     // Must be a subview or a slice to guarantee there are loops we can fuse
     // into.
@@ -332,6 +332,22 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOf(
   return llvm::None;
 }
 
+// Tries RAW dependences first, then WAW; other kinds are not fused atm.
+Optional<FusionInfo> mlir::linalg::fuseProducerOf(
+    OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
+    const LinalgDependenceGraph &graph, OperationFolder *folder) {
+  using DepType = LinalgDependenceGraph::DependenceType;
+  // Order matters: the first dependence kind that yields a fusion wins.
+  const DepType fusableDepTypes[] = {DepType::RAW, DepType::WAW};
+  for (DepType depType : fusableDepTypes) {
+    Optional<FusionInfo> fused = fuseProducerOfDep(b, consumer, consumerIdx,
+                                                   graph, folder, depType);
+    if (fused)
+      return fused;
+  }
+  return llvm::None;
+}
+
 /// Checks if two Generic ops are fusible, when one is a producer and another is
 /// a consumer (with the result of the producer being the `consumerIdx` operand
 /// of the consumer).
@@ -498,7 +514,8 @@ static void fuseLinalgOpsGreedily(FuncOp f) {
   // The current naive and expensive reconstruction of the graph should be
   // removed.
   for (auto *op : llvm::reverse(linalgOps)) {
-    for (unsigned id = 0, e = LinalgOp(op).getNumInputs(); id < e; ++id) {
+    for (unsigned id = 0, e = LinalgOp(op).getNumInputsAndOutputBuffers();
+         id < e; ++id) {
       linalg::Aliases aliases;
       linalg::LinalgDependenceGraph graph(aliases, linalgOps);
       if (auto info = fuseProducerOf(b, op, id, graph, &folder)) {

diff  --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir
index f844f76a3da6..82ef196d0d97 100644
--- a/mlir/test/Dialect/Linalg/fusion.mlir
+++ b/mlir/test/Dialect/Linalg/fusion.mlir
@@ -41,12 +41,11 @@ func @f1(%A: memref<?x?xf32, offset: 0, strides: [?, 1]>,
 }
 // CHECK-LABEL: func @f1
 // CHECK:   (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
-// No RAW dependences, the pass does not fuse RAR atm.
-// CHECK: linalg.matmul
 // CHECK: loop.for
 // CHECK:   loop.for
 // CHECK:     loop.for
 // CHECK:       linalg.matmul
+// CHECK:       linalg.matmul
 
 // -----
 
@@ -334,15 +333,13 @@ func @f6(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
 }
 // CHECK-LABEL: func @f6
 // CHECK:  (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
-// Cannot fuse C due to interleaved read of C that would be bypassed.
-// Cannot fuse E (WAW).
-// CHECK:  linalg.matmul
-// CHECK:  linalg.matmul
+// Fuse the producer of E (WAW) then the producer of C (RAW).
 // CHECK:  loop.for
 // CHECK:    loop.for
 // CHECK:      loop.for
 // CHECK:        linalg.matmul
-// CHECK-NOT:      linalg.matmul
+// CHECK:        linalg.matmul
+// CHECK:        linalg.matmul
 
 // -----
 
@@ -785,3 +782,53 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
 // CHECK:       linalg.generic
 // CHECK:         exp
 // CHECK:         linalg.yield
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
+#map1 = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)>
+#map2 = affine_map<()[s0] -> (s0 + 3)>
+
+func @fill_and_conv(%arg0: memref<1x4x5x1xf32>, %arg1: memref<2x3x1x1xf32>, %arg2: memref<1x4x5x1xf32>) {
+  %cst = constant 0.000000e+00 : f32
+  linalg.fill(%arg2, %cst) : memref<1x4x5x1xf32>, f32
+  // Zero-fill %arg2; the tiled conv below writes subviews of it (WAW dep).
+  %c4 = constant 4 : index
+  %c1 = constant 1 : index
+  %c0 = constant 0 : index
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
+  %4 = dim %arg1, 0 : memref<2x3x1x1xf32>
+  %5 = dim %arg1, 1 : memref<2x3x1x1xf32>
+  %6 = dim %arg0, 0 : memref<1x4x5x1xf32>
+  %7 = dim %arg0, 1 : memref<1x4x5x1xf32>
+  %8 = dim %arg0, 3 : memref<1x4x5x1xf32>
+  %9 = dim %arg2, 0 : memref<1x4x5x1xf32>
+  %10 = dim %arg2, 1 : memref<1x4x5x1xf32>
+  %11 = dim %arg2, 2 : memref<1x4x5x1xf32>
+  %12 = dim %arg2, 3 : memref<1x4x5x1xf32>
+  %13 = linalg.range %c0 : %6 : %c2 : !linalg.range
+  %14 = linalg.range %c0 : %10 : %c3 : !linalg.range
+  loop.for %arg3 = %c0 to %6 step %c2 {
+    loop.for %arg4 = %c0 to %10 step %c3 {
+      %15 = affine.min #map0(%c2, %c1, %arg3)
+      %16 = affine.apply #map2()[%7]
+      %17 = affine.min #map0(%16, %c4, %arg4)
+      %18 = dim %arg0, 2 : memref<1x4x5x1xf32>
+      %19 = dim %arg0, 3 : memref<1x4x5x1xf32>
+      %20 = subview %arg0[%arg3, %arg4, %c0, %c0] [%15, %17, %18, %19] [%c1, %c1, %c1, %c1] : memref<1x4x5x1xf32> to memref<?x?x?x?xf32, #map1>
+      %21 = affine.min #map0(%c2, %c1, %arg3)
+      %22 = affine.min #map0(%c3, %c4, %arg4)
+      %23 = dim %arg2, 2 : memref<1x4x5x1xf32>
+      %24 = dim %arg2, 3 : memref<1x4x5x1xf32>
+      %25 = subview %arg2[%arg3, %arg4, %c0, %c0] [%21, %22, %23, %24] [%c1, %c1, %c1, %c1] : memref<1x4x5x1xf32> to memref<?x?x?x?xf32, #map1>
+      linalg.conv(%arg1, %20, %25) {dilations = [1, 1], strides = [1, 1]} : memref<2x3x1x1xf32>, memref<?x?x?x?xf32, #map1>, memref<?x?x?x?xf32, #map1>
+    }
+  }
+  return
+}
+// CHECK-LABEL: func @fill_and_conv
+// CHECK: loop.for
+// CHECK:   loop.for
+// CHECK:     linalg.fill
+// CHECK:     linalg.conv


        


More information about the Mlir-commits mailing list