[Mlir-commits] [mlir] 6dd696a - [mlir][Linalg] Extend fusion to support WAW atm on buffers.
Hanhan Wang
llvmlistbot at llvm.org
Tue Mar 31 21:34:52 PDT 2020
Author: Hanhan Wang
Date: 2020-03-31T21:33:50-07:00
New Revision: 6dd696ae4fa1b1564e76e5531c268724d2c8b98f
URL: https://github.com/llvm/llvm-project/commit/6dd696ae4fa1b1564e76e5531c268724d2c8b98f
DIFF: https://github.com/llvm/llvm-project/commit/6dd696ae4fa1b1564e76e5531c268724d2c8b98f.diff
LOG: [mlir][Linalg] Extend fusion to support WAW atm on buffers.
Summary:
The RAW fusion happens only if the producer block dominates the consumer block.
The WAW pattern works under the same precondition, i.e., if the producer
dominates the consumer, the two ops can be fused together.
Since both ops are tilable, we can reason about the pattern as follows:
Input:
```
linalg_op1 view
tile_loop
subview_2
linalg_op2 subview_2
```
Tile the first Linalg op the same way as the second Linalg op.
```
tile_loop
subview_1
linalg_op1 subview_1
tile_loop
subview_2
linalg_op2 subview_2
```
Since the first Linalg op is tilable in the same way and the tiled computations
are independent, it is reasonable to fuse it with the second Linalg op.
```
tile_loop
subview_1
linalg_op1 subview_1
linalg_op2 subview_2
```
In short, this patch includes:
- Handling both RAW and WAW patterns.
- Adding an interface method to get an input or output buffer by position.
- Exposing a method to get the StringRef name of a dependence type.
- Fixing existing WAW tests and adding one more use case: initializing the buffer
before a conv op.
Differential Revision: https://reviews.llvm.org/D76897
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/test/Dialect/Linalg/fusion.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
index 34de176a998e..e40d63661b77 100644
--- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
+++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
@@ -63,6 +63,7 @@ class LinalgDependenceGraph {
using dependence_range = iterator_range<dependence_iterator>;
enum DependenceType { RAR = 0, RAW, WAR, WAW, NumTypes };
+ static StringRef getDependenceTypeStr(DependenceType depType);
// Builds a linalg dependence graph for the ops of type LinalgOp under `f`.
static LinalgDependenceGraph buildDependenceGraph(Aliases &aliases, FuncOp f);
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index 8fcc1ceea502..46fb9881aba5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -100,6 +100,10 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
//===------------------------------------------------------------------===//
// Input and Output arguments handling.
//===------------------------------------------------------------------===//
+ InterfaceMethod<
+ "Return one single buffer at position `$i`.",
+ "Value", "getBuffer", (ins "unsigned":$i)
+ >,
InterfaceMethod<
"Return the number of inputs and outputs, irrespective of their buffer "
"or tensor type.",
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
index f546d3670b6a..b13b6d268226 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -184,6 +184,10 @@ class StructuredOpTraits
//==========================================================================//
// Input and Output arguments handling.
//==========================================================================//
+ Value getBuffer(unsigned i) {
+ assert(i < getNumInputsAndOutputBuffers() && "overflowing buffers index");
+ return this->getOperation()->getOperand(i);
+ }
/// Return the number of inputs and outputs, irrespective of their buffer or
/// tensor type.
unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); }
diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
index bf52039a6dc1..90ce8fd6bb0b 100644
--- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
+++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
@@ -24,24 +24,6 @@ using namespace mlir::linalg;
using llvm::dbgs;
-#ifndef NDEBUG
-static StringRef toStringRef(LinalgDependenceGraph::DependenceType dt) {
- switch (dt) {
- case LinalgDependenceGraph::DependenceType::RAW:
- return "RAW";
- case LinalgDependenceGraph::DependenceType::RAR:
- return "RAR";
- case LinalgDependenceGraph::DependenceType::WAR:
- return "WAR";
- case LinalgDependenceGraph::DependenceType::WAW:
- return "WAW";
- default:
- break;
- }
- llvm_unreachable("Unexpected DependenceType");
-}
-#endif
-
Value Aliases::find(Value v) {
if (v.isa<BlockArgument>())
return v;
@@ -76,6 +58,22 @@ Value Aliases::find(Value v) {
}
}
+StringRef LinalgDependenceGraph::getDependenceTypeStr(DependenceType depType) {
+ switch (depType) {
+ case LinalgDependenceGraph::DependenceType::RAW:
+ return "RAW";
+ case LinalgDependenceGraph::DependenceType::RAR:
+ return "RAR";
+ case LinalgDependenceGraph::DependenceType::WAR:
+ return "WAR";
+ case LinalgDependenceGraph::DependenceType::WAW:
+ return "WAW";
+ default:
+ break;
+ }
+ llvm_unreachable("Unexpected DependenceType");
+}
+
LinalgDependenceGraph
LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, FuncOp f) {
SmallVector<Operation *, 8> linalgOps;
@@ -100,7 +98,7 @@ LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
void LinalgDependenceGraph::addDependenceElem(DependenceType dt,
LinalgOpView indexingOpView,
LinalgOpView dependentOpView) {
- LLVM_DEBUG(dbgs() << "\nAdd dep type " << toStringRef(dt) << ":\t"
+ LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t"
<< *indexingOpView.op << " -> " << *dependentOpView.op);
dependencesFromGraphs[dt][indexingOpView.op].push_back(
LinalgDependenceGraphElem{dependentOpView, indexingOpView.view});
@@ -227,8 +225,8 @@ LinalgDependenceGraph::findOperationsWithCoveringDependences(
continue;
auto *op = dependence.dependentOpView.op;
LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type "
- << toStringRef(dt) << ": " << *src << " -> " << *op
- << " on " << dependence.indexingView);
+ << getDependenceTypeStr(dt) << ": " << *src << " -> "
+ << *op << " on " << dependence.indexingView);
res.push_back(op);
}
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index b6af16c979c3..4d20bb541e28 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -157,9 +157,9 @@ static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
}
auto subView = dyn_cast_or_null<SubViewOp>(
- consumer.getInput(consumerIdx).getDefiningOp());
- auto slice =
- dyn_cast_or_null<SliceOp>(consumer.getInput(consumerIdx).getDefiningOp());
+ consumer.getBuffer(consumerIdx).getDefiningOp());
+ auto slice = dyn_cast_or_null<SliceOp>(
+ consumer.getBuffer(consumerIdx).getDefiningOp());
assert(subView || slice);
(void)subView;
(void)slice;
@@ -274,16 +274,15 @@ bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
return true;
}
-// Only consider RAW atm.
-Optional<FusionInfo> mlir::linalg::fuseProducerOf(
- OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
- const LinalgDependenceGraph &graph, OperationFolder *folder) {
+static Optional<FusionInfo>
+fuseProducerOfDep(OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
+ const LinalgDependenceGraph &graph, OperationFolder *folder,
+ LinalgDependenceGraph::DependenceType depType) {
assert(consumer.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
LLVM_DEBUG(dbgs() << "\nStart examining consumer: "
<< *consumer.getOperation());
- for (auto dependence : graph.getDependencesInto(
- consumer, LinalgDependenceGraph::DependenceType::RAW)) {
+ for (auto dependence : graph.getDependencesInto(consumer, depType)) {
LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
<< *dependence.dependentOpView.op << "\n");
auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
@@ -294,7 +293,7 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOf(
// Check that the dependence is indeed on the input `consumerIdx` view.
auto consumedView = dependence.indexingView;
- if (consumer.getInput(consumerIdx) != consumedView)
+ if (consumer.getBuffer(consumerIdx) != consumedView)
continue;
// Consumer consumes this view, `isStructurallyFusableProducer` also checks
@@ -302,9 +301,10 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOf(
auto producedView = dependence.dependentOpView.view;
auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue();
// `consumerIdx` and `producerIdx` exist by construction.
- LLVM_DEBUG(dbgs() << "\nRAW producer: " << *producer.getOperation()
- << " view: " << producedView
- << " output index: " << producerIdx);
+ LLVM_DEBUG(dbgs() << "\n"
+ << LinalgDependenceGraph::getDependenceTypeStr(depType)
+ << "producer: " << *producer.getOperation() << " view: "
+ << producedView << " output index: " << producerIdx);
// Must be a subview or a slice to guarantee there are loops we can fuse
// into.
@@ -332,6 +332,22 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOf(
return llvm::None;
}
+// Only consider RAW and WAW atm.
+Optional<FusionInfo> mlir::linalg::fuseProducerOf(
+ OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
+ const LinalgDependenceGraph &graph, OperationFolder *folder) {
+ SmallVector<LinalgDependenceGraph::DependenceType, 4> deps = {
+ LinalgDependenceGraph::DependenceType::RAW,
+ LinalgDependenceGraph::DependenceType::WAW,
+ };
+ for (auto dep : deps) {
+ if (auto res =
+ fuseProducerOfDep(b, consumer, consumerIdx, graph, folder, dep))
+ return res;
+ }
+ return llvm::None;
+}
+
/// Checks if two Generic ops are fusible, when one is a producer and another is
/// a consumer (with the result of the producer being the `consumerIdx` operand
/// of the consumer).
@@ -498,7 +514,8 @@ static void fuseLinalgOpsGreedily(FuncOp f) {
// The current naive and expensive reconstruction of the graph should be
// removed.
for (auto *op : llvm::reverse(linalgOps)) {
- for (unsigned id = 0, e = LinalgOp(op).getNumInputs(); id < e; ++id) {
+ for (unsigned id = 0, e = LinalgOp(op).getNumInputsAndOutputBuffers();
+ id < e; ++id) {
linalg::Aliases aliases;
linalg::LinalgDependenceGraph graph(aliases, linalgOps);
if (auto info = fuseProducerOf(b, op, id, graph, &folder)) {
diff --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir
index f844f76a3da6..82ef196d0d97 100644
--- a/mlir/test/Dialect/Linalg/fusion.mlir
+++ b/mlir/test/Dialect/Linalg/fusion.mlir
@@ -41,12 +41,11 @@ func @f1(%A: memref<?x?xf32, offset: 0, strides: [?, 1]>,
}
// CHECK-LABEL: func @f1
// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
-// No RAW dependences, the pass does not fuse RAR atm.
-// CHECK: linalg.matmul
// CHECK: loop.for
// CHECK: loop.for
// CHECK: loop.for
// CHECK: linalg.matmul
+// CHECK: linalg.matmul
// -----
@@ -334,15 +333,13 @@ func @f6(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
}
// CHECK-LABEL: func @f6
// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
-// Cannot fuse C due to interleaved read of C that would be bypassed.
-// Cannot fuse E (WAW).
-// CHECK: linalg.matmul
-// CHECK: linalg.matmul
+// Fuse the producer of E (WAW) then the producer of C (WAR).
// CHECK: loop.for
// CHECK: loop.for
// CHECK: loop.for
// CHECK: linalg.matmul
-// CHECK-NOT: linalg.matmul
+// CHECK: linalg.matmul
+// CHECK: linalg.matmul
// -----
@@ -785,3 +782,53 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
// CHECK: linalg.generic
// CHECK: exp
// CHECK: linalg.yield
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
+#map1 = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)>
+#map2 = affine_map<()[s0] -> (s0 + 3)>
+
+func @fill_and_conv(%arg0: memref<1x4x5x1xf32>, %arg1: memref<2x3x1x1xf32>, %arg2: memref<1x4x5x1xf32>) {
+ %cst = constant 0.000000e+00 : f32
+ linalg.fill(%arg2, %cst) : memref<1x4x5x1xf32>, f32
+
+ %c4 = constant 4 : index
+ %c1 = constant 1 : index
+ %c0 = constant 0 : index
+ %c2 = constant 2 : index
+ %c3 = constant 3 : index
+ %4 = dim %arg1, 0 : memref<2x3x1x1xf32>
+ %5 = dim %arg1, 1 : memref<2x3x1x1xf32>
+ %6 = dim %arg0, 0 : memref<1x4x5x1xf32>
+ %7 = dim %arg0, 1 : memref<1x4x5x1xf32>
+ %8 = dim %arg0, 3 : memref<1x4x5x1xf32>
+ %9 = dim %arg2, 0 : memref<1x4x5x1xf32>
+ %10 = dim %arg2, 1 : memref<1x4x5x1xf32>
+ %11 = dim %arg2, 2 : memref<1x4x5x1xf32>
+ %12 = dim %arg2, 3 : memref<1x4x5x1xf32>
+ %13 = linalg.range %c0 : %6 : %c2 : !linalg.range
+ %14 = linalg.range %c0 : %10 : %c3 : !linalg.range
+ loop.for %arg3 = %c0 to %6 step %c2 {
+ loop.for %arg4 = %c0 to %10 step %c3 {
+ %15 = affine.min #map0(%c2, %c1, %arg3)
+ %16 = affine.apply #map2()[%7]
+ %17 = affine.min #map0(%16, %c4, %arg4)
+ %18 = dim %arg0, 2 : memref<1x4x5x1xf32>
+ %19 = dim %arg0, 3 : memref<1x4x5x1xf32>
+ %20 = subview %arg0[%arg3, %arg4, %c0, %c0] [%15, %17, %18, %19] [%c1, %c1, %c1, %c1] : memref<1x4x5x1xf32> to memref<?x?x?x?xf32, #map1>
+ %21 = affine.min #map0(%c2, %c1, %arg3)
+ %22 = affine.min #map0(%c3, %c4, %arg4)
+ %23 = dim %arg2, 2 : memref<1x4x5x1xf32>
+ %24 = dim %arg2, 3 : memref<1x4x5x1xf32>
+ %25 = subview %arg2[%arg3, %arg4, %c0, %c0] [%21, %22, %23, %24] [%c1, %c1, %c1, %c1] : memref<1x4x5x1xf32> to memref<?x?x?x?xf32, #map1>
+ linalg.conv(%arg1, %20, %25) {dilations = [1, 1], strides = [1, 1]} : memref<2x3x1x1xf32>, memref<?x?x?x?xf32, #map1>, memref<?x?x?x?xf32, #map1>
+ }
+ }
+ return
+}
+// CHECK-LABEL: func @fill_and_conv
+// CHECK: loop.for
+// CHECK: loop.for
+// CHECK: linalg.fill
+// CHECK: linalg.conv
More information about the Mlir-commits
mailing list