[Mlir-commits] [mlir] 8cf650c - [mlir][linalg] Add support for WAW fusion on tensors.
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Apr 16 01:27:11 PDT 2021
Author: Nicolas Vasilache
Date: 2021-04-16T08:22:09Z
New Revision: 8cf650c554441935f232d10725403da5038f597e
URL: https://github.com/llvm/llvm-project/commit/8cf650c554441935f232d10725403da5038f597e
DIFF: https://github.com/llvm/llvm-project/commit/8cf650c554441935f232d10725403da5038f597e.diff
LOG: [mlir][linalg] Add support for WAW fusion on tensors.
Differential Revision: https://reviews.llvm.org/D100603
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
index fecaeff1c8df4..17fa57d341ca5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
+++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
@@ -229,6 +229,10 @@ class LinalgDependenceGraph {
ArrayRef<DependenceType> depTypes = {
DependenceType::RAW, DependenceType::WAW}) const;
+ void print(raw_ostream &os) const;
+
+ void dump() const;
+
private:
// Keep dependences in both directions, this is not just a performance gain
// but it also reduces usage errors.
diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
index 4b92667b0cc01..3e37979b68ec2 100644
--- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
+++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
@@ -162,6 +162,8 @@ LinalgDependenceGraph::getDependencesInto(
}
void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
+ LLVM_DEBUG(dbgs() << "addDependencesBetween " << *src.getOperation()
+ << " and " << *dst.getOperation() << "\n");
if (src.hasTensorSemantics() && dst.hasTensorSemantics()) {
for (OpOperand &dstOpOperand : dst.getInputOpOperands()) {
// Check if the operand is defined by the src.
@@ -170,6 +172,18 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
addDependenceElem(DependenceType::RAW, dstOpOperand.get(),
&dstOpOperand);
}
+ for (OpOperand &dstOpOperand : dst.getOutputOpOperands()) {
+ // Check if the operand is defined by the src.
+ auto definingOp = dstOpOperand.get().getDefiningOp<LinalgOp>();
+ if (definingOp && definingOp == src) {
+ if (dst.isInitTensor(&dstOpOperand)) {
+ addDependenceElem(DependenceType::RAW, dstOpOperand.get(),
+ &dstOpOperand);
+ }
+ addDependenceElem(DependenceType::WAW, dstOpOperand.get(),
+ &dstOpOperand);
+ }
+ }
return;
}
assert(src.hasBufferSemantics() && dst.hasBufferSemantics() &&
@@ -322,3 +336,21 @@ LinalgDependenceGraph::getDependentOperations(
dependentOperations.append(t.begin(), t.end());
return dependentOperations;
}
+
+void LinalgDependenceGraph::print(raw_ostream &os) const {
+ for (auto dt : {
+ LinalgDependenceGraph::DependenceType::RAW,
+ LinalgDependenceGraph::DependenceType::WAW,
+ }) {
+ const auto &fromGraph = dependencesFromGraphs[dt];
+ for (const auto &it : fromGraph) {
+ os << "[LinalgDependenceGraph] DT " << dt << " from: " << *it.first
+ << ":\n";
+ for (const auto &dep : it.second) {
+ os << "\tDT " << dt << " " << *dep.getDependentOp() << ":\n";
+ }
+ }
+ }
+}
+
+void LinalgDependenceGraph::dump() const { print(llvm::errs()); }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 247e4e5c26bb5..e8759ee882675 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -29,6 +29,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@@ -331,6 +332,10 @@ bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
findFusableProducer(OpOperand &consumerOpOperand,
const LinalgDependenceGraph &dependenceGraph) {
+ LLVM_DEBUG(llvm::dbgs() << "findFusableProducer for: "
+ << consumerOpOperand.get() << " @"
+ << consumerOpOperand.getOperandNumber() << " in "
+ << *consumerOpOperand.getOwner() << "\n");
LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
if (!consumerOp)
return {};
@@ -340,9 +345,14 @@ findFusableProducer(OpOperand &consumerOpOperand,
LinalgDependenceGraph::DependenceType::RAW,
LinalgDependenceGraph::DependenceType::WAW,
}) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Dependencies into: " << *consumerOp.getOperation() << "\n");
for (auto dependence : llvm::make_filter_range(
dependenceGraph.getDependencesInto(consumerOp, depType),
[&](LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
+ LLVM_DEBUG(llvm::dbgs() << "Inspect dependence btw: "
+ << elem.getIndexingValue() << " and "
+ << elem.getDependentValue() << "\n");
Value v = elem.getIndexingValue();
Optional<unsigned> operandNum =
elem.getIndexingOpViewOperandNum();
@@ -783,12 +793,14 @@ static Optional<TiledLinalgOp> tileRootOperation(
/// `fusionCandidates`, i.e. move the operation within the inter-tile loops of
/// `tiledOp`.
static SmallVector<LinalgOp, 1>
-fuseOperations(OpBuilder &builder, LinalgOp rootOp, LinalgOp tiledOp,
+fuseOperations(OpBuilder &builder, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp,
ArrayRef<LinalgOp> fusionCandidates,
const FusableOpDependencesTy &fusableDependences,
const std::set<unsigned> &fusedLoops) {
+ LinalgOp tiledOp = tiledLinalgOp.op;
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPoint(tiledOp);
+
DenseMap<unsigned, Range> fusedLoopsAndRanges;
for (unsigned loop : fusedLoops) {
ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop, true);
@@ -804,27 +816,49 @@ fuseOperations(OpBuilder &builder, LinalgOp rootOp, LinalgOp tiledOp,
LinalgOp fusedOp = fuse(builder, origOp, fusedLoopsAndRanges);
origOpToFusedOp[origOp.getOperation()] = fusedOp;
fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
+
+ // Prepare the builder for the next insertion point.
+ auto guard =
+ llvm::make_scope_exit([&]() { builder.setInsertionPoint(fusedOp); });
+ if (!origOp.hasTensorSemantics())
+ continue;
+
// If the producer consumer operations are linalg operations on tensors, the
// dependence is due to value produced (as a return tensor) by the producer
// and used in the consumer. The returned value of the fused op needs to be
// made the operand of the tiled/fused consumer operation. By construction
// the value returned by the producer is the value used by the consumer.
for (auto &dependence : fusableDependences.lookup(origOp.getOperation())) {
- if (origOp.hasTensorSemantics() &&
- dependence.dependenceType ==
- LinalgDependenceGraph::DependenceType::RAW) {
- unsigned resultIndex =
- dependence.getDependentOpViewResultNum().getValue();
- LinalgOp consumer = origOpToFusedOp.lookup(dependence.getIndexingOp());
- if (!consumer)
- continue;
- Value replacementValue = fusedOp.getOperation()->getResult(resultIndex);
- consumer.getOperation()->setOperand(
- dependence.getIndexingOpViewOperandNum().getValue(),
- replacementValue);
- }
+ if (dependence.dependenceType !=
+ LinalgDependenceGraph::DependenceType::RAW)
+ continue;
+
+ unsigned resultIndex =
+ dependence.getDependentOpViewResultNum().getValue();
+ LinalgOp consumer = origOpToFusedOp.lookup(dependence.getIndexingOp());
+ if (!consumer)
+ continue;
+
+ Value replacementValue = fusedOp.getOperation()->getResult(resultIndex);
+ consumer.getOperation()->setOperand(
+ dependence.getIndexingOpViewOperandNum().getValue(),
+ replacementValue);
}
- builder.setInsertionPoint(fusedOp);
+
+ // At this point, all Linalg uses of the tensors produced by `origOp` have
+ // been replaced. However, there may still be "output tensor"-like uses
+ // coming from WAW dependencies.
+ // All these uses are iter_args of the outermost loop (TODO: add a check).
+ // Such iter_args uses serve 2 purposes:
+ // 1. give a shape to the output
+ // 2. encode destructive updates that may be inplaceable by bufferization.
+ // To keep the second type of information while letting the unfused op die
+ // unused, we need to forward the producer output operand.
+ for (auto &operand :
+ cast<scf::ForOp>(tiledLinalgOp.loops.front()).getIterOpOperands())
+ if (auto opResult = operand.get().dyn_cast<OpResult>())
+ if (opResult.getOwner() == origOp)
+ operand.set(origOp.getOutputTensors()[opResult.getResultNumber()]);
}
return fusedOps;
}
@@ -860,18 +894,23 @@ tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
ScopedContext scope(builder, rootOp.getLoc());
// Find all the producers.
+ LLVM_DEBUG(llvm::dbgs() << "findAllFusableDependences\n");
FusableOpDependencesTy fusableDependences =
findAllFusableDependences(ops, dependenceGraph);
- if (fusableDependences.empty())
+ if (fusableDependences.empty()) {
+ LLVM_DEBUG(llvm::dbgs() << "no fusable dependencies found\n");
return llvm::None;
+ }
TiledAndFusedLinalgOps ret;
// Find the loops that can be tiled and fused.
+ LLVM_DEBUG(llvm::dbgs() << "collectFusableLoops\n");
ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences);
// If there are no fusable dependences or there are no tile+fusable loops,
// just return.
if (ret.fusedLoopDims.empty()) {
+ LLVM_DEBUG(llvm::dbgs() << "no fusable loops found\n");
return llvm::None;
}
@@ -888,8 +927,9 @@ tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
// Fuse the other operations into the fused inter-tile loops produced above.
- ret.fusedProducers = fuseOperations(builder, rootOp, ret.op, ops.drop_back(),
- fusableDependences, ret.fusedLoopDims);
+ ret.fusedProducers =
+ fuseOperations(builder, rootOp, *tiledRootOp, ops.drop_back(),
+ fusableDependences, ret.fusedLoopDims);
return ret;
}
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
index b24324bba110d..3775b67c43548 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
@@ -143,3 +143,34 @@ module {
// CHECK: scf.yield %[[UPDATE]]
// CHECK: scf.yield %[[YIELD]]
// CHECK: return %[[RESULT]]
+
+// -----
+
+module {
+ func @matmul_out_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+ %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %c0 = constant 0.0 : f32
+ %0 = linalg.fill(%arg0, %c0) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
+ %1 = linalg.matmul {__internal_linalg_transform__ = "out_fusion"}
+ ins(%arg1, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+ }
+}
+
+// CHECK-LABEL: func @matmul_out_fusion(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK: %[[C0:.*]] = constant 0.0{{.*}} : f32
+// CHECK-NOT: fill
+// CHECK: scf.for %[[I:.*]]{{.*}}iter_args(%{{.*}} = %[[ARG0]]) -> (tensor<?x?xf32>) {
+// CHECK: scf.for %[[J:.*]]
+// CHECK: %[[ST:.*]] = subtensor %[[ARG0]]
+// CHECK: %[[ST_FILL:.*]] = linalg.fill(%[[ST]], %[[C0]]) {__internal_linalg_transform__ = "after_out_fusion_producer"} : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
+// CHECK: %[[ST_MM_RES:.*]] = scf.for %[[K:.*]]{{.*}}iter_args(%[[BB:.*]] = %[[ST_FILL]]) -> (tensor<?x?xf32>) {
+// CHECK-NOT: fill
+// CHECK: %[[ST_MM:.*]] = linalg.matmul {__internal_linalg_transform__ = "after_out_fusion"} ins(%{{.*}}, %{{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[BB]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: scf.yield %[[ST_MM]] : tensor<?x?xf32>
+// CHECK: %[[MM:.*]] = subtensor_insert %[[ST_MM_RES]] into {{.*}}
+// CHECK: scf.yield %[[MM]] : tensor<?x?xf32>
diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index 3ef6ed5e4b4ba..3e67774ba13a4 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -52,6 +52,19 @@ static void fillFusionPatterns(MLIRContext *context,
ArrayRef<Identifier>(),
Identifier::get("after_lhs_fusion_original", context)));
+ patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
+ context, dependenceGraph,
+ LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
+ LinalgFusionOptions().setIndicesToFuse({2}),
+ LinalgTransformationFilter(Identifier::get("out_fusion", context),
+ Identifier::get("after_out_fusion", context)),
+ LinalgTransformationFilter(
+ ArrayRef<Identifier>(),
+ Identifier::get("after_out_fusion_producer", context)),
+ LinalgTransformationFilter(
+ ArrayRef<Identifier>(),
+ Identifier::get("after_out_fusion_original", context)));
+
patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
context, dependenceGraph,
LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
More information about the Mlir-commits
mailing list