[Mlir-commits] [mlir] f8284d2 - [mlir][Linalg] Fuse sequence of Linalg operations (on buffers)
llvmlistbot at llvm.org
Thu Nov 19 19:03:22 PST 2020
Author: MaheshRavishankar
Date: 2020-11-19T19:03:06-08:00
New Revision: f8284d21a8e294d58a0acd4b8b2e906d7a9f110c
URL: https://github.com/llvm/llvm-project/commit/f8284d21a8e294d58a0acd4b8b2e906d7a9f110c
DIFF: https://github.com/llvm/llvm-project/commit/f8284d21a8e294d58a0acd4b8b2e906d7a9f110c.diff
LOG: [mlir][Linalg] Fuse sequence of Linalg operations (on buffers)
Enhance the tile+fuse logic to allow fusing a sequence of operations.
Differential Revision: https://reviews.llvm.org/D90991
Added:
mlir/test/Dialect/Linalg/fusion-sequence.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/fusion-pattern.mlir
mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 8d531a1e343a..fac91ca0b256 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -37,14 +37,6 @@ struct TiledLinalgOp {
SmallVector<Value, 4> tensorResults;
};
-struct TiledAndFusedLinalgOps {
- LinalgOp op;
- SmallVector<LinalgOp, 1> fusedProducers;
- SmallVector<LinalgOp, 1> originalProducers;
- SmallVector<Operation *, 4> fusedLoops;
- SmallVector<Operation *, 4> unfusedLoops;
-};
-
/// Populates patterns for vectorization of all ConvN-D ops.
void populateConvVectorizationPatterns(
MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
@@ -73,14 +65,11 @@ void populateLinalgBufferizePatterns(MLIRContext *context,
Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
const LinalgTilingOptions &options);
-/// Tile and fuse the `op` with its producers. The tile and fuse proceeds in
-/// three steps
-/// - Find tile loops that are fusable with its producer tile loops (a.k.a. tile
-/// + fuse loops).
-/// - Tile just these loops of the consumer (root operation) and fuse with
-/// the producer.
-/// - Tile again the tiled consumer operation produced above to do rest of
-/// the tiling specified by the `tilingOptions`.
+/// Fuse a sequence of linalg operations (`ops`) using tile-and-fuse. This
+/// proceeds as follows:
+/// - Find outer parallel loops in these ops that can be fused.
+/// - Tile fusable outer parallel loops of the last operation in the sequence.
+/// - Fuse the remaining operations with the tiled operation.
///
/// For example, consider the sequence of matmul below
///
@@ -107,36 +96,39 @@ Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
/// : memref<256x32xf32> to memref<16x32xf32, #map0>
/// %3 = subview %arg1[0, 0] [32, 32] [1, 1]
/// : memref<32x32xf32> to memref<32x32xf32, #map1>
+/// %4 = subview %arg3[0, 0] [32, 32] [1, 1]
+/// : memref<32x32xf32> to memref<32x32xf32, #map1>
/// linalg.matmul
/// ins(%2, %3 : memref<16x32xf32, #map0>, memref<32x32xf32, #map1>)
/// outs(%0 : memref<16x32xf32, #map0>)
-/// scf.parallel (%arg6) = (%c0) to (%c32) step (%c8) {
-/// scf.for %arg7 = %c0 to %c32 step %c4 {
-/// %4 = subview %0[0, %arg7] [16, 4] [1, 1]
-/// : memref<16x32xf32, #map0> to memref<16x4xf32, #map0>
-/// %5 = subview %arg3[%arg7, %arg6] [4, 8] [1, 1]
-/// : memref<32x32xf32> to memref<4x8xf32, #map0>
-/// %6 = subview %1[0, %arg6] [16, 8] [1, 1]
-/// : memref<16x32xf32, #map0> to memref<16x8xf32, #map0>
-/// linalg.matmul
-/// ins(%4, %5 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>)
-/// outs(%6 : memref<16x8xf32, #map0>)
-/// }
-/// scf.yield
-/// }
-/// scf.yield
+/// linalg.matmul
+/// ins(%0, %4 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>)
+/// outs(%1 : memref<16x8xf32, #map0>)
/// }
///
-/// The following tiling options are handled differently in tile+fuse (compared
-/// to tile only)
+/// `tilingOptions` are used to tile the corresponding operation in `ops` (the
+/// size of the former should be the same as the size of the latter). Based on
+/// how tile+fuse is implemented, the fused loops are generated based on the
+/// last operation in the sequence. For example, the tile sizes for the fused
+/// loops are obtained from `tilingOptions.back()`. The following tiling
+/// options are handled differently in tile+fuse (compared to tile only)
/// - Interchange of the tiling loops is not supported right now.
-/// - Distribution is only done for the tile+fuse loops. The tiled loops
-/// generated by the second tiling is not distributed.
+/// - Only the fused loops are distributed.
+struct TiledAndFusedLinalgOps {
+ /// Operation obtained by tiling the last operation in the sequence of `ops`
+ /// passed to `tileAndFuseLinalgOps`.
+ LinalgOp op;
+ /// The dimensions of the loops that are fused.
+ std::set<unsigned> fusedLoopDims;
+ /// The generated fused operations (created within the fused loops).
+ SmallVector<LinalgOp, 1> fusedProducers;
+ /// The fused loops generated.
+ SmallVector<Operation *, 4> fusedLoops;
+};
Optional<TiledAndFusedLinalgOps>
-tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
+tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
const LinalgDependenceGraph &dependenceGraph,
- const LinalgTilingOptions &tilingOptions,
- const LinalgFusionOptions &fusionOptions);
+ const LinalgTilingOptions &tilingOptions);
/// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`.
/// This is an in-place transformation controlled by `interchangeVector`.
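For orientation, here is a minimal sketch of how a caller could drive the new entry point declared above. It mirrors the test pass added later in this patch; `funcOp` and the tile sizes {16, 32, 64} are placeholder assumptions, not part of this change.

  // Collect the Linalg ops of a single-block function in program order; the
  // last op in the list acts as the root of the fusion.
  SmallVector<LinalgOp, 2> linalgOps =
      llvm::to_vector<2>(funcOp.getBody().front().getOps<LinalgOp>());

  Aliases aliases;
  LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
  OpBuilder builder(funcOp.getContext());

  // Tile sizes correspond to the loops of the root (last) operation; only the
  // fusable outer parallel loops are tiled and fused by this call.
  Optional<TiledAndFusedLinalgOps> tiledAndFused = tileAndFuseLinalgOps(
      builder, linalgOps, dependenceGraph,
      LinalgTilingOptions()
          .setTileSizes({16, 32, 64})
          .setLoopType(LinalgTilingLoopType::ParallelLoops));
  if (!tiledAndFused)
    return; // No fusable dependences or no tile+fusable loops.

The loops that did get fused are reported in `fusedLoopDims`, so a caller can tile the remaining loops separately afterwards, as the updated LinalgBaseTileAndFusePattern in Transforms.cpp does.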
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 969bea4a4549..02417bd78d9a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -178,6 +178,9 @@ static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
Value shape = en.value();
SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
for (auto en2 : llvm::enumerate(map.getResults())) {
+ auto dimExpr = en2.value().dyn_cast<AffineDimExpr>();
+ if (!dimExpr)
+ continue;
if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
<< loopDepth << "\n");
@@ -190,49 +193,18 @@ static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
llvm_unreachable("Expect to be able to extract a shape defining loop range");
}
-/// Fuses the producer of `producerIdx` into the loop immediately enclosing
-/// `consumer`. This is achieved by "recomputing" the `producer` at the time it
-/// is needed just before the `consumer.
-///
-/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
-/// 2 cases:
-/// 1. Buffer case: `producerIdx` is the index of the buffer in
-/// `producer.getOutputBuffers()`.
-/// 2. Tensor case: `producerIdx` is the index of the tensor in
-/// `producer.getResults()`.
-static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
- LinalgOp consumer, unsigned consumerIdx) {
- Operation *shapeProducingOp =
- consumer.getShapedOperand(consumerIdx).getDefiningOp();
- assert((isa<SubViewOp>(shapeProducingOp) ||
- isa<SubTensorOp>(shapeProducingOp)) &&
- "SubviewOp or SubTensorOp expected");
-
- // loopToOperandRangesMaps are permutations-only by construction:
- // we can always identify a data dimension with a (at least one) loop
- // dimension.
- // TODO: extend this with range inference.
- AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
- LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
- << ", producer map: " << producerMap << "\n");
+/// Fuse the producer by cloning the `producer`. The `fusedLoopsAndRanges`
+/// provides the loop range information for the fused loops. The rest are
+/// obtained from the producer itself, since they are not tiled + fused.
+static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
+ const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
unsigned nPar = producer.getNumParallelLoops();
unsigned nRed = producer.getNumReductionLoops();
unsigned nWin = producer.getNumWindowLoops();
SmallVector<Range, 8> loopRanges(nPar + nRed + nWin);
-
- // Iterate over dimensions identified by the producer map for `producerIdx`.
- // This defines a subset of the loop ranges that we need to complete later.
- auto loc = consumer.getLoc();
- for (auto en : llvm::enumerate(producerMap.getResults())) {
- unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
- loopRanges[posInProducerLoop] =
- isa<SubViewOp>(shapeProducingOp)
- ? cast<SubViewOp>(shapeProducingOp)
- .getOrCreateRanges(b, loc)[en.index()]
- : cast<SubTensorOp>(shapeProducingOp)
- .getOrCreateRanges(b, loc)[en.index()];
- }
+ for (auto fusedLoops : fusedLoopsAndRanges)
+ loopRanges[fusedLoops.first] = fusedLoops.second;
// Iterate over all dimensions. For the dimensions not identified by the
// producer map for `producerIdx`, we need to explicitly compute the shape
@@ -250,7 +222,45 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
}
}
- return cloneWithLoopRanges(b, loc, producer, loopRanges);
+ return cloneWithLoopRanges(b, producer.getLoc(), producer, loopRanges);
+}
+
+/// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is
+/// expected to be defined by a subview op or a subtensor op.
+static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
+ Value shapedOperand, unsigned dim) {
+ Operation *shapeProducingOp = shapedOperand.getDefiningOp();
+ if (auto subViewOp = dyn_cast<SubViewOp>(shapeProducingOp))
+ return subViewOp.getOrCreateRanges(b, loc)[dim];
+ if (auto subTensorOp = dyn_cast<SubTensorOp>(shapeProducingOp))
+ return subTensorOp.getOrCreateRanges(b, loc)[dim];
+ llvm_unreachable("SubviewOp or SubTensorOp expected");
+}
+
+/// Fuses the producer of `producerIdx` into the loop immediately enclosing
+/// `consumer`. This is achieved by "recomputing" the `producer` at the time it
+/// is needed just before the `consumer`.
+///
+/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
+/// 2 cases:
+/// 1. Buffer case: `producerIdx` is the index of the buffer in
+/// `producer.getOutputBuffers()`.
+/// 2. Tensor case: `producerIdx` is the index of the tensor in
+/// `producer.getResults()`.
+static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
+ LinalgOp consumer, unsigned consumerIdx) {
+ AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
+ LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
+ << ", producer map: " << producerMap << "\n");
+ DenseMap<unsigned, Range> fusedLoopsAndRanges;
+ Location loc = consumer.getLoc();
+ Value shapedOperand = consumer.getShapedOperand(consumerIdx);
+ for (auto en : llvm::enumerate(producerMap.getResults())) {
+ unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
+ fusedLoopsAndRanges[posInProducerLoop] =
+ getRangeFromOperandShape(b, loc, shapedOperand, en.index());
+ }
+ return fuse(b, producer, fusedLoopsAndRanges);
}
// Encode structural fusion safety preconditions.
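As a rough illustration of the refactored helpers above (a sketch, not code from this patch): to fuse a producer along just its loop dimension 0, a caller collects the Range of that dimension from the consumer's subview operand and lets `fuse` recompute everything else. This assumes the producer's indexing map for that operand sends loop 0 to result dimension 0; `b`, `producer`, `consumer`, and `consumerIdx` are taken as given.

  // The range of the fused dimension comes from the consumer's subview (or
  // subtensor) operand; ranges of the remaining, untiled loops are recomputed
  // inside `fuse` from the producer's own shapes.
  DenseMap<unsigned, Range> fusedLoopsAndRanges;
  Value shapedOperand = consumer.getShapedOperand(consumerIdx);
  fusedLoopsAndRanges[0] =
      getRangeFromOperandShape(b, consumer.getLoc(), shapedOperand, /*dim=*/0);
  LinalgOp fusedProducer = fuse(b, producer, fusedLoopsAndRanges);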
@@ -525,6 +535,69 @@ using FusableOpDependencesTy = llvm::MapVector<
Operation *,
SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>;
+/// Returns the mapping from iterations in the consumer that write to the same
+/// location as the iterations in the producer. To do so, use
+/// - indexing map of the fused view in the consumer : consumerIndexMap
+/// - indexing map of the fused view in the producer : producerIndexMap
+/// consumerLoopToProducerLoop =
+/// inverse(producerIndexMap).compose(consumerIndexMap)
+static Optional<AffineMap> getConsumerLoopToProducerLoopMap(
+ LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
+ auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
+ AffineMap producerIndexingMap =
+ producer.getIndexingMap(dependence.dependentOpView.operandIndex);
+ auto consumer = cast<LinalgOp>(dependence.indexingOpView.op);
+ AffineMap consumerIndexingMap =
+ consumer.getIndexingMap(dependence.indexingOpView.operandIndex);
+
+ AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
+ producer.iterator_types().getValue(), producerIndexingMap);
+ if (!prunedProducerIndexingMap.isPermutation())
+ return None;
+
+ if (consumerIndexingMap.getNumResults() !=
+ prunedProducerIndexingMap.getNumResults())
+ return None;
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "\t producerMap : ";
+ producerIndexingMap.print(llvm::dbgs());
+ llvm::dbgs() << " pruned : ";
+ prunedProducerIndexingMap.print(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ llvm::dbgs() << "\t consumerMap : ";
+ consumerIndexingMap.print(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+
+ AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap);
+ if (!invProducerIndexMap)
+ return None;
+
+ return invProducerIndexMap.compose(consumerIndexingMap);
+}
+
+/// Given a projected permutation `map`, returns true if the map changes the
+/// order in which the fused loop dimensions appear.
+static bool doesTransposeAccess(AffineMap map,
+ const std::set<unsigned> &fusableLoops) {
+ Optional<unsigned> lastFusableLoop;
+ for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) {
+ return expr.cast<AffineDimExpr>().getPosition();
+ })) {
+ if (!fusableLoops.count(pos))
+ continue;
+ if (!lastFusableLoop) {
+ lastFusableLoop = pos;
+ continue;
+ }
+ if (pos <= lastFusableLoop.getValue())
+ return true;
+ lastFusableLoop = pos;
+ }
+ return false;
+}
+
/// Returns the positions of the loop in `op` that can be tiled based on the
/// operations that are to be fused with it. For example, in a
///
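A worked example for the two helpers added above, getConsumerLoopToProducerLoopMap and doesTransposeAccess (an illustrative sketch; the maps are invented, not taken from this patch, and the snippet only needs mlir/IR/AffineMap.h and mlir/IR/AffineExpr.h):

  // Producer writes the fused view with (d0, d1) -> (d0, d1); the consumer
  // reads the same view with (d0, d1) -> (d1, d0).
  MLIRContext context;
  MLIRContext *ctx = &context;
  AffineExpr d0 = getAffineDimExpr(0, ctx);
  AffineExpr d1 = getAffineDimExpr(1, ctx);
  AffineMap producerMap = AffineMap::get(2, 0, {d0, d1}, ctx);
  AffineMap consumerMap = AffineMap::get(2, 0, {d1, d0}, ctx);

  // consumerLoopToProducerLoop = inverse(producerMap).compose(consumerMap)
  //                            = (d0, d1) -> (d1, d0).
  AffineMap consumerToProducer =
      inversePermutation(producerMap).compose(consumerMap);

  // With fusableLoops = {0, 1}, the result positions appear as 1 then 0, i.e.
  // out of order, so doesTransposeAccess(consumerToProducer, {0, 1}) returns
  // true and the dependence is rejected (the implementation limitation called
  // out in collectFusableLoops).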
@@ -538,13 +611,7 @@ using FusableOpDependencesTy = llvm::MapVector<
/// 2. Of the parallel loops only some can be fused. Only those loops can be
/// fused such where the fusable loops iteration space only touches one tile
/// of the fused operation. This is because the producer (which is writing
-/// the fused subview) has update semantics. To compute this,
-/// a. Find the mapping from iterations in the consumer that write to the
-/// same location as the iterations in the producer. To do so use
-/// - indexing map of the fused view in the consumer : consumerIndexMap
-/// - indexing map of the fused view in the producer : producerIndexMap
-/// consumerLoopToProducerLoop =
-/// inverse(producerIndexMap).compose(consumerIndexMap)
+/// the fused subview) has update semantics.
///
/// Since an inverse computation is needed, we need to consider the projection
/// of the producerIndexMap w.r.t the parallel loops. The actual fusable loops
@@ -582,8 +649,9 @@ using FusableOpDependencesTy = llvm::MapVector<
/// submap with only parallel loops = affine_map<(i, j) -> (j)>
/// Fused dimensions : j
static std::set<unsigned>
-collectTileAndFuseLoops(LinalgOp op,
- const FusableOpDependencesTy &fusableDependences) {
+collectFusableLoops(ArrayRef<LinalgOp> ops,
+ const FusableOpDependencesTy &fusableDependences) {
+ assert(!ops.empty());
auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
return linalgOp.iterator_types()
.getValue()
@@ -594,88 +662,57 @@ collectTileAndFuseLoops(LinalgOp op,
.size();
};
- LLVM_DEBUG({
- llvm::dbgs() << "Op : ";
- op.getOperation()->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n";
- });
-
- size_t numOuterParallelLoops = getNumOuterParallelLoops(op);
- for (auto dependence : fusableDependences) {
- linalg::LinalgOp producer = cast<linalg::LinalgOp>(dependence.first);
+ size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back());
+ for (auto op : ops.drop_back()) {
numOuterParallelLoops =
- std::min(numOuterParallelLoops, getNumOuterParallelLoops(producer));
+ std::min(numOuterParallelLoops, getNumOuterParallelLoops(op));
}
std::set<unsigned> fusableLoops;
auto range = llvm::seq<unsigned>(0, numOuterParallelLoops);
fusableLoops.insert(range.begin(), range.end());
- for (auto dependence : fusableDependences) {
- LLVM_DEBUG({
- llvm::dbgs() << "\t fusable :";
- for (unsigned i : fusableLoops)
- llvm::dbgs() << " " << i;
- llvm::dbgs() << "\n";
- });
- linalg::LinalgOp producer = cast<linalg::LinalgOp>(dependence.first);
-
- assert(!dependence.second.empty() &&
- "unexpected producer but not dependences");
- AffineMap producerIndexingMap = producer.getIndexingMap(
- dependence.second.front().dependentOpView.operandIndex);
- AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
- producer.iterator_types().getValue(), producerIndexingMap);
- if (!prunedProducerIndexingMap.isPermutation())
- return {};
-
- AffineMap consumerIndexingMap = op.getIndexingMap(
- dependence.second.front().indexingOpView.operandIndex);
- if (consumerIndexingMap.getNumResults() !=
- prunedProducerIndexingMap.getNumResults())
- return {};
-
- LLVM_DEBUG({
- llvm::dbgs() << "\t producerMap : ";
- producerIndexingMap.print(llvm::dbgs());
- llvm::dbgs() << " pruned : ";
- prunedProducerIndexingMap.print(llvm::dbgs());
- llvm::dbgs() << "\n";
- llvm::dbgs() << "\t consumerMap : ";
- consumerIndexingMap.print(llvm::dbgs());
- llvm::dbgs() << "\n";
- });
-
- AffineMap invProducerIndexMap =
- inversePermutation(prunedProducerIndexingMap);
- if (!invProducerIndexMap)
- return {};
-
- AffineMap consumerLoopToProducerLoop =
- invProducerIndexMap.compose(consumerIndexingMap);
-
- LLVM_DEBUG({
- llvm::dbgs() << "\t consumerLoopToProducerLoop : ";
- consumerLoopToProducerLoop.print(llvm::dbgs());
- });
-
- std::set<unsigned> candidates;
- for (AffineExpr expr : consumerLoopToProducerLoop.getResults()) {
- AffineDimExpr dimExpr = expr.dyn_cast<AffineDimExpr>();
- if (!dimExpr)
- continue;
- unsigned position = dimExpr.getPosition();
- if (fusableLoops.count(position))
- candidates.insert(position);
+
+ for (auto op : reverse(ops)) {
+ for (auto dependence : fusableDependences.lookup(op)) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "\t fusable :";
+ for (unsigned i : fusableLoops)
+ llvm::dbgs() << " " << i;
+ llvm::dbgs() << "\n";
+ });
+
+ Optional<AffineMap> consumerLoopToProducerLoop =
+ getConsumerLoopToProducerLoopMap(dependence);
+ if (!consumerLoopToProducerLoop) {
+ op.emitRemark("failed to get map from consumer loop to producer loop");
+ return {};
+ }
+ // TODO: This condition is only an implementation limitation. When fusing
+ // the operation, if the accesses in the producer/consumer are transposes
+ // of each other, the loop bounds for the tiled producer can be
+ // manipulated accordingly. This requires some additional bookkeeping in
+ // the implementation of tile+fuse that is deferred to later.
+ if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) {
+ op.emitRemark("unhandled fusion when fusion requires permutation");
+ return {};
+ }
+
+ std::set<unsigned> candidates;
+ for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) {
+ unsigned position = expr.cast<AffineDimExpr>().getPosition();
+ if (fusableLoops.count(position))
+ candidates.insert(position);
+ }
+ LLVM_DEBUG({
+ llvm::dbgs() << "\t candidates :";
+ for (unsigned i : candidates)
+ llvm::dbgs() << " " << i;
+ llvm::dbgs() << "\n";
+ });
+ if (candidates.empty())
+ return {};
+ std::swap(candidates, fusableLoops);
}
- LLVM_DEBUG({
- llvm::dbgs() << "\t candidates :";
- for (unsigned i : candidates)
- llvm::dbgs() << " " << i;
- llvm::dbgs() << "\n";
- });
- if (candidates.empty())
- return {};
- std::swap(candidates, fusableLoops);
}
return fusableLoops;
@@ -683,60 +720,69 @@ collectTileAndFuseLoops(LinalgOp op,
/// Find all dependences that are to be fusable.
static FusableOpDependencesTy
-findAllFusableDependences(LinalgOp op,
- const LinalgDependenceGraph &dependenceGraph,
- const LinalgFusionOptions &fusionOptions) {
+findAllFusableDependences(ArrayRef<LinalgOp> ops,
+ const LinalgDependenceGraph &dependenceGraph) {
FusableOpDependencesTy fusableDependences;
// TODO: Currently fusion would not be legal if the fusable dependence is to
// the same producer but different indexing map in the consumer. Fix this, but
// in the meanwhile disallow such a fusion.
DenseMap<Operation *, AffineMap> fusedProducerIndexingMap;
- for (auto operandIndex : fusionOptions.indicesToFuse) {
- auto fusableDependence =
- findFusableProducer(op, operandIndex, dependenceGraph);
- if (!fusableDependence)
- return FusableOpDependencesTy{};
- LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
- // Do not fuse dependences that are to operations not in the same basic
- // block. This avoid moving fused operations across loops that might
- // themselves carry dependency making the fusion illegal.
- if (producerOp.getOperation()->getBlock() !=
- op.getOperation()->getBlock()) {
- op.emitRemark("unhandled fusion of ops in
diff erent basic blocks");
- return FusableOpDependencesTy{};
- }
- // Make sure that the indexing map of the view used for fusion in the
- // producer is a projected permutation.
- unsigned producerIdx = fusableDependence->dependentOpView.operandIndex;
- AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
- if (!producerMap.isProjectedPermutation()) {
- op.emitRemark("unhandled non permutation indexing map for fused view in "
- "producer for operand at index ")
- << operandIndex;
- return FusableOpDependencesTy{};
- }
+ for (LinalgOp op : reverse(ops)) {
+ for (auto operandIndex :
+ llvm::seq<unsigned>(0, op.getNumInputsAndOutputBuffers())) {
+ Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
+ fusableDependence =
+ findFusableProducer(op, operandIndex, dependenceGraph);
+ if (!fusableDependence)
+ continue;
+ LinalgOp producerOp =
+ cast<LinalgOp>(fusableDependence->dependentOpView.op);
+ // Do not fuse dependences that are to operations not in the same basic
+ // block. This avoids moving fused operations across loops that might
+ // themselves carry a dependency, making the fusion illegal.
+ if (producerOp.getOperation()->getBlock() !=
+ op.getOperation()->getBlock()) {
+ op.emitRemark("unhandled fusion of ops in
diff erent basic blocks");
+ return FusableOpDependencesTy{};
+ }
+ // Make sure that the indexing map of the view used for fusion in the
+ // producer is a projected permutation.
+ unsigned producerIdx = fusableDependence->dependentOpView.operandIndex;
+ AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
+ if (!producerMap.isProjectedPermutation()) {
+ op.emitRemark(
+ "unhandled non permutation indexing map for fused view in "
+ "producer for operand at index ")
+ << operandIndex;
+ return FusableOpDependencesTy{};
+ }
- unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex;
- AffineMap consumerMap = op.getIndexingMap(consumerIdx);
- if (!consumerMap.isProjectedPermutation()) {
- op.emitRemark(
- "unhandled case where indexing map for fused view in the consumer is "
- "not a projected permutation while fusing at index ")
- << operandIndex;
- return FusableOpDependencesTy{};
- }
+ unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex;
+ AffineMap consumerMap = op.getIndexingMap(consumerIdx);
+ if (!consumerMap.isProjectedPermutation()) {
+ op.emitRemark(
+ "unhandled case where indexing map for fused view in the consumer "
+ "is "
+ "not a projected permuration while fusing at index ")
+ << operandIndex;
+ return FusableOpDependencesTy{};
+ }
- // Check if the producer is already a fusion candidate. Cannot fuse this
- // dependence if it has a different indexing map when used in the consumer.
- if (fusedProducerIndexingMap.count(producerOp.getOperation()) &&
- fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) {
- op.emitRemark("unhandled fusion to the same producer but with
diff erent "
- "indexing maps");
- return FusableOpDependencesTy{};
- }
- fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap;
+ // Check if the producer is already a fusion candidate. Cannot fuse this
+ // dependence if it has a different indexing map when used in the
+ // consumer.
+ if (fusedProducerIndexingMap.count(producerOp.getOperation()) &&
+ fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) {
+ op.emitRemark(
+ "unhandled fusion to the same producer but with
diff erent "
+ "indexing maps");
+ return FusableOpDependencesTy{};
+ }
+ fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap;
- fusableDependences[producerOp.getOperation()].push_back(*fusableDependence);
+ fusableDependences[producerOp.getOperation()].push_back(
+ *fusableDependence);
+ }
}
return fusableDependences;
}
@@ -747,136 +793,120 @@ static bool isZero(Value v) {
return false;
}
+/// Tile the fused loops in the root operation, by setting the tile sizes for
+/// all other loops to zero (those will be tiled later).
+static Optional<TiledLinalgOp> tileRootOperation(
+ OpBuilder &builder, LinalgOp op, ArrayRef<Value> tileSizeVector,
+ const LinalgTilingOptions &options, const std::set<unsigned> &fusedLoops) {
+ SmallVector<Value, 4> tileSizes(tileSizeVector.begin(), tileSizeVector.end());
+ auto zero = std_constant_index(0);
+ for (unsigned i = 0, e = tileSizes.size(); i != e; ++i)
+ if (!fusedLoops.count(i))
+ tileSizes[i] = zero;
+ LinalgTilingOptions tileFusedLoopsOptions = options;
+ tileFusedLoopsOptions.setTileSizes(tileSizes);
+ return tileLinalgOp(builder, op, tileFusedLoopsOptions);
+}
+
+/// Fuse the operations in `fusionCandidates` with `tiledOp`. The latter is
+/// expected to be a tiled operation such that it is valid to fuse all
+/// operations in `fusionCandidates`, i.e. to move them within the inter-tile
+/// loops of `tiledOp`.
+static SmallVector<LinalgOp, 1>
+fuseOperations(OpBuilder &builder, LinalgOp tiledOp,
+ ArrayRef<LinalgOp> fusionCandidates,
+ const FusableOpDependencesTy &fusableDependences,
+ const std::set<unsigned> &fusedLoops) {
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPoint(tiledOp);
+ DenseMap<unsigned, Range> fusedLoopsAndRanges;
+ for (unsigned loop : fusedLoops) {
+ ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop);
+ fusedLoopsAndRanges[loop] = getRangeFromOperandShape(
+ builder, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension);
+ }
+ SmallVector<LinalgOp, 1> fusedOps(fusionCandidates.size());
+ for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) {
+ LinalgOp fusedOp = fuse(builder, candidate.value(), fusedLoopsAndRanges);
+ fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
+ builder.setInsertionPoint(fusedOp);
+ }
+ return fusedOps;
+}
+
template <typename LoopType>
static Optional<TiledAndFusedLinalgOps>
-tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op,
+tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
const LinalgDependenceGraph &dependenceGraph,
- const LinalgTilingOptions &tilingOptions,
- const LinalgFusionOptions &fusionOptions) {
- assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
- // Some of the tiling options might not be supportable with tile and fuse.
- // TODO: Support interchange with tile + fuse.
+ const LinalgTilingOptions &tilingOptions) {
+ if (ops.empty())
+ return llvm::None;
+ LinalgOp rootOp = ops.back();
+ for (auto op : enumerate(ops)) {
+ // TODO: Nothing in the fusion of a sequence of ops is specific to
+ // buffers. This check can be removed after it is tested on tensors.
+ LinalgOp linalgOp = op.value();
+ if (!linalgOp.hasBufferSemantics()) {
+ linalgOp.emitError("tile and fuse only tested for buffer operation");
+ return llvm::None;
+ }
+ }
+ // TODO: Support interchange with tile + fuse. This might actually help do
+ // better fusion.
if (!tilingOptions.interchangeVector.empty()) {
- op.emitError("unable to handle tile and fuse with interchange");
+ rootOp.emitError("unable to handle tile and fuse with interchange");
return llvm::None;
}
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(op);
- ScopedContext scope(rewriter, op.getLoc());
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPoint(rootOp);
+ ScopedContext scope(builder, rootOp.getLoc());
// Find all the producers.
FusableOpDependencesTy fusableDependences =
- findAllFusableDependences(op, dependenceGraph, fusionOptions);
+ findAllFusableDependences(ops, dependenceGraph);
if (fusableDependences.empty())
return llvm::None;
- // Enforce the convention that "tiling by zero" skips tiling a particular
- // dimension. This convention is significantly simpler to handle instead of
- // adjusting affine maps to account for missing dimensions.
- auto nLoops = op.getNumLoops();
- SmallVector<Value, 4> tileSizeVector =
- tilingOptions.tileSizeComputationFunction(rewriter, op);
- if (tileSizeVector.size() < nLoops) {
- auto zero = std_constant_index(0);
- tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
- }
-
TiledAndFusedLinalgOps ret;
-
// Find the loops that can be tiled and fused.
- std::set<unsigned> tileFuseLoops =
- collectTileAndFuseLoops(op, fusableDependences);
+ ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences);
// If there are no fusable dependences or there are no tile+fusable loops,
// just return.
- if (tileFuseLoops.empty()) {
+ if (ret.fusedLoopDims.empty()) {
return llvm::None;
}
- // Get the tile sizes for the first and second tiling steps. For the first
- // step the tile size are set to zero for the loops that arent
- // fused. Similarly for the second step, the tile sizes are set to zero for
- // the loops that are fused. For example, if for the following input
- //
- // ```
- // linalg.add ins(%a, %b) outs(%c)
- // linalg.matmul ins(%d, %c) outs(%e)
- // ```
- //
- // if the tile sizes of the `{i, j, k}` loops where given as `{ti, tj, tk}`
- // respectively, and since only `j` can be tiled and fused. The tile sizes
- // would be `{0, t_j, 0}` for the first tiling that tiles just the fusable
- // loops. The second tiling would be use tile sizes of `{t_i, 0, t_k}` to tile
- // the tiled matmul generated by the first tiling step.
- SmallVector<Value, 4> tileAndFuseSizes, tileSizes;
- for (auto tileSize : enumerate(tileSizeVector)) {
- auto zero = std_constant_index(0);
- if (tileFuseLoops.count(tileSize.index())) {
- tileAndFuseSizes.push_back(tileSize.value());
- tileSizes.push_back(zero);
- } else {
- tileSizes.push_back(tileSize.value());
- tileAndFuseSizes.push_back(zero);
- }
- }
-
- // Tile for the loops that can be fused.
- LinalgTilingOptions firstTilingOptions = tilingOptions;
- firstTilingOptions.setTileSizes(tileAndFuseSizes);
- Optional<TiledLinalgOp> firstTiledOp =
- tileLinalgOp(rewriter, op, firstTilingOptions);
- if (!firstTiledOp)
+ // Tile the fused loops in the last operation in the list.
+ SmallVector<Value, 4> tileSizeVector =
+ tilingOptions.tileSizeComputationFunction(builder, rootOp);
+ Optional<TiledLinalgOp> tiledRootOp = tileRootOperation(
+ builder, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
+ if (!tiledRootOp) {
+ rootOp.emitError("failed to tile the fused loops");
return llvm::None;
- ret.op = firstTiledOp->op;
- ret.fusedLoops.assign(firstTiledOp->loops.begin(), firstTiledOp->loops.end());
-
- rewriter.setInsertionPoint(ret.op);
- // Fuse the operands.
- for (auto dependence : fusableDependences) {
- LinalgOp producerOp = cast<LinalgOp>(dependence.first);
- unsigned producerIdx =
- dependence.second.front().dependentOpView.operandIndex;
- unsigned consumerIdx =
- dependence.second.front().indexingOpView.operandIndex;
- LinalgOp fusedOp = fuse(rewriter, producerOp,
- producerOp.getOutputIndex(producerIdx).getValue(),
- ret.op, consumerIdx);
- ret.fusedProducers.push_back(fusedOp);
- ret.originalProducers.push_back(producerOp);
- }
-
- if (!llvm::all_of(tileSizes, isZero)) {
- // Tile the remaining loops of the root operation.
- LinalgTilingOptions secondTilingOptions = tilingOptions;
- // The distribution is done only for the tile+fused loops.
- secondTilingOptions.distribution = llvm::None;
- secondTilingOptions.setTileSizes(tileSizes);
- Optional<TiledLinalgOp> secondTiledOp =
- tileLinalgOp(rewriter, ret.op, secondTilingOptions);
- if (!secondTiledOp)
- return llvm::None;
- ret.unfusedLoops.assign(secondTiledOp->loops.begin(),
- secondTiledOp->loops.end());
- rewriter.eraseOp(ret.op);
- ret.op = secondTiledOp->op;
}
+ ret.op = tiledRootOp->op;
+ ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
+ // Fuse the other operations into the fused inter-tile loops produced above.
+ ret.fusedProducers = fuseOperations(builder, ret.op, ops.drop_back(),
+ fusableDependences, ret.fusedLoopDims);
return ret;
}
Optional<TiledAndFusedLinalgOps>
-mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
+mlir::linalg::tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
const LinalgDependenceGraph &dependenceGraph,
- const LinalgTilingOptions &tilingOptions,
- const LinalgFusionOptions &fusionOptions) {
+ const LinalgTilingOptions &tilingOptions) {
switch (tilingOptions.loopType) {
case LinalgTilingLoopType::Loops:
- return tileAndFuseLinalgOpsImpl<scf::ForOp>(rewriter, op, dependenceGraph,
- tilingOptions, fusionOptions);
+ return tileAndFuseLinalgOpsImpl<scf::ForOp>(builder, ops, dependenceGraph,
+ tilingOptions);
case LinalgTilingLoopType::ParallelLoops:
return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(
- rewriter, op, dependenceGraph, tilingOptions, fusionOptions);
+ builder, ops, dependenceGraph, tilingOptions);
default:;
}
return llvm::None;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 836cc28e0a47..a855c07cb8d4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -165,17 +165,69 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
if (!linalgOp.hasBufferSemantics())
return failure();
+ DenseSet<Operation *> producers;
+ producers.insert(linalgOp);
+ for (auto dependence : dependenceGraph.getDependentOperations(linalgOp)) {
+ if (!fusionOptions.indicesToFuse.count(
+ dependence.indexingOpView.operandIndex))
+ continue;
+ if (isa<LinalgOp>(dependence.dependentOpView.op))
+ producers.insert(dependence.dependentOpView.op);
+ }
+
+ SmallVector<LinalgOp, 1> fusionOps;
+ for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
+ ++it) {
+ auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
+ if (producerLinalgOp && producers.count(producerLinalgOp))
+ fusionOps.push_back(producerLinalgOp);
+ }
+ fusionOps.push_back(linalgOp);
+
+ SmallVector<Value, 4> tileSizes =
+ tilingOptions.tileSizeComputationFunction(rewriter, op);
+ LinalgTilingOptions instanceTilingOptions = tilingOptions;
+ instanceTilingOptions.setTileSizes(tileSizes);
Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
- rewriter, op, dependenceGraph, tilingOptions, fusionOptions);
+ rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
if (!tiledAndFusedOps)
return failure();
+
+ // Tile the unfused loops.
+ SmallVector<Value, 4> unfusedLoopTileSizes;
+ Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0);
+ for (auto tileSize : enumerate(tileSizes)) {
+ if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
+ unfusedLoopTileSizes.push_back(zero);
+ else
+ unfusedLoopTileSizes.push_back(tileSize.value());
+ }
+ // Tile the loop only if there is a non-zero tile size.
+ if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
+ unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
+ if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
+ if (auto cst = val.getDefiningOp<ConstantIndexOp>())
+ return cst.getValue() != 0;
+ return true;
+ })) {
+ LinalgTilingOptions unfusedTilingOptions = tilingOptions;
+ unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
+ Optional<TiledLinalgOp> unfusedTiledOp =
+ tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
+ if (!unfusedTiledOp)
+ return failure();
+ rewriter.eraseOp(tiledAndFusedOps->op);
+ tiledAndFusedOps->op = unfusedTiledOp->op;
+ }
+
marker.replaceLinalgMarker(rewriter, tiledAndFusedOps->op.getOperation());
for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
fusedOpMarker.replaceLinalgMarker(rewriter, fusedOp.getOperation());
}
- for (auto origProducerOp : tiledAndFusedOps->originalProducers)
+ for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
originalOpMarker.replaceLinalgMarker(rewriter,
origProducerOp.getOperation());
+ }
rewriter.updateRootInPlace(
op, [&]() { originalOpMarker.replaceLinalgMarker(rewriter, op); });
return success();
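To make the two-step tiling in the updated pattern concrete (an illustrative walk-through, not code from this patch): suppose the user requests tile sizes {16, 32, 64} for the {i, j, k} loops of the root op and only the i loop turns out to be fusable. tileAndFuseLinalgOps tiles the fused loop first, and the code above then tiles the rest with the complementary sizes:

  // User tile sizes (i, j, k):          {16, 32, 64}
  // fusedLoopDims reported by fusion:   {0}
  // First tiling (fused loops only):    {16,  0,  0}  -> scf.parallel over i,
  //                                                      the only loops that
  //                                                      may be distributed
  // Second tiling (remaining loops):    { 0, 32, 64}  -> loops over j and k

This relies on the convention, kept from the previous implementation, that a tile size of zero skips tiling of that loop, and it explains why distribution is applied only to the fused loops.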
diff --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir
index 2ddc66651db2..fa471811ef4e 100644
--- a/mlir/test/Dialect/Linalg/fusion-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-pattern.mlir
@@ -47,7 +47,9 @@ module {
// CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N_2]]]
// CHECK: %[[SV3:.+]] = subview %[[ARG2]][%[[IV0]], %[[IV1]]]
// CHECK-SAME: [%[[TILE_M_2]], %[[TILE_N_2]]]
-// CHECK: linalg.fill(%[[SV3]], %[[CST]])
+// CHECK: %[[SV3_2:.+]] = subview %[[ARG2]][%[[IV0]], %[[IV1]]]
+// CHECK-SAME: [%[[TILE_M]], %[[TILE_N]]]
+// CHECK: linalg.fill(%[[SV3_2]], %[[CST]])
// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_producer"
// CHECK: scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] {
// CHECK: %[[TILE_K:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K]]]
@@ -109,9 +111,12 @@ module {
// CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[N_2]]]
// CHECK: %[[SV2:.+]] = subview %[[ARG3]][0, %[[IV0]]]
// CHECK-SAME: [%[[M]], %[[TILE_N_2]]]
+// CHECK: %[[K_2:.+]] = dim %[[ARG1]], %[[C0]]
// CHECK: %[[SV3:.+]] = subview %[[ARG1]][0, %[[IV0]]]
-// CHECK-SAME: [%[[K]], %[[TILE_N]]]
-// CHECK: linalg.copy(%[[SV3]], %[[SV1]])
+// CHECK-SAME: [%[[K_2]], %[[TILE_N]]]
+// CHECK: %[[SV3_2:.+]] = subview %[[ARG2]][0, %[[IV0]]]
+// CHECK-SAME: [%[[K_2]], %[[TILE_N]]]
+// CHECK: linalg.copy(%[[SV3]], %[[SV3_2]])
// CHECK-SAME: __internal_linalg_transform__ = "after_rhs_fusion_producer"
// CHECK-NOT: linalg.fill
// CHECK-DAG: %[[M_2:.+]] = dim %[[ARG0]], %[[C0]]
@@ -186,11 +191,16 @@ module {
// CHECK: %[[N:.+]] = dim %[[ARG3]], %[[C1]]
// CHECK: %[[SV2:.+]] = subview %[[ARG3]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M_2]], %[[N]]]
+// CHECK: %[[SV2_2:.+]] = subview %[[ARG3]][%[[IV0]], 0]
+// CHECK-SAME: [%[[TILE_M]], %[[N]]]
+// CHECK: %[[K_2:.+]] = dim %[[ARG0]], %[[C1]]
// CHECK: %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[K]]]
-// CHECK: linalg.copy(%[[SV3]], %[[SV1]])
+// CHECK-SAME: [%[[TILE_M]], %[[K_2]]]
+// CHECK: %[[SV3_2:.+]] = subview %[[ARG1]][%[[IV0]], 0]
+// CHECK-SAME: [%[[TILE_M]], %[[K_2]]]
+// CHECK: linalg.copy(%[[SV3]], %[[SV3_2]])
// CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion_producer"
-// CHECK: linalg.fill(%[[SV2]], %[[CST]])
+// CHECK: linalg.fill(%[[SV2_2]], %[[CST]])
// CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion_producer"
// CHECK-DAG: %[[N_2:.+]] = dim %[[ARG2]], %[[C1]]
// CHECK: scf.parallel (%[[IV1:.+]]) =
@@ -261,15 +271,18 @@ module {
// CHECK: %[[N:.+]] = dim %[[ARG4]], %[[C1]]
// CHECK: %[[SV2:.+]] = subview %[[ARG4]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M_2]], %[[N]]]
+// CHECK: %[[K2_2:.+]] = dim %[[ARG1]], %[[C1]]
// CHECK: %[[K1:.+]] = dim %[[ARG0]], %[[C1]]
// CHECK: %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M]], %[[K1]]]
-// CHECK: %[[SV4:.+]] = subview %[[ARG1]][0, 0] [%[[K1]], %[[K2]]]
+// CHECK: %[[SV4:.+]] = subview %[[ARG1]][0, 0] [%[[K1]], %[[K2_2]]]
+// CHECK: %[[SV1_2:.+]] = subview %[[ARG2]][%[[IV0]], 0]
+// CHECK-SAME: [%[[TILE_M]], %[[K2_2]]]
// CHECK: linalg.matmul
// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_producer"
// CHECK-SAME: ins(%[[SV3]], %[[SV4]]
// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
-// CHECK-SAME: outs(%[[SV1]] : memref<?x?xf32, #[[MAP1]]>)
+// CHECK-SAME: outs(%[[SV1_2]] : memref<?x?xf32, #[[MAP1]]>)
// CHECK-DAG: %[[N_2:.+]] = dim %[[ARG3]], %[[C1]]
// CHECK: scf.parallel (%[[IV1:.+]]) =
// CHECK-SAME: (%[[C0]]) to (%[[N_2]]) step (%[[C64]]) {
diff --git a/mlir/test/Dialect/Linalg/fusion-sequence.mlir b/mlir/test/Dialect/Linalg/fusion-sequence.mlir
new file mode 100644
index 000000000000..a02c878ef341
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fusion-sequence.mlir
@@ -0,0 +1,133 @@
+// RUN: mlir-opt -pass-pipeline="func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),canonicalize,cse" -split-input-file %s | FileCheck %s
+
+module {
+ func @three_op_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
+ %arg2: memref<?xf32>, %arg3 : memref<?x?xf32>) {
+ %cst = constant 0.000000e+00 : f32
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %d0 = dim %arg0, %c0 : memref<?x?xf32>
+ %d1 = dim %arg1, %c1 : memref<?x?xf32>
+ %0 = alloc(%d0, %d1) : memref<?x?xf32>
+ linalg.fill(%0, %cst) : memref<?x?xf32>, f32
+ linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
+ outs(%0 : memref<?x?xf32>)
+ linalg.generic
+ {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%0, %arg2 : memref<?x?xf32>, memref<?xf32>)
+ outs(%arg3 : memref<?x?xf32>) {
+ ^bb0(%arg4 : f32, %arg5 : f32, %arg6 : f32) :
+ %5 = addf %arg4, %arg5 : f32
+ linalg.yield %5 : f32
+ }
+ return
+ }
+}
+
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK: func @three_op_fusion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?xf32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK: %[[TEMP:.+]] = alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32>
+// CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) = {{.*}} {
+// CHECK-DAG: %[[SV_TEMP:.+]] = subview %[[TEMP]][%[[IV0]], %[[IV1]]]
+// CHECK-DAG: %[[SV_ARG2:.+]] = subview %[[ARG2]][%[[IV1]]]
+// CHECK-DAG: %[[SV_ARG3:.+]] = subview %[[ARG3]][%[[IV0]], %[[IV1]]]
+// CHECK-DAG: %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[IV0]], 0]
+// CHECK-DAG: %[[SV_ARG1:.+]] = subview %[[ARG1]][0, %[[IV1]]]
+// CHECK: linalg.fill(%[[SV_TEMP]], %{{.+}})
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%[[SV_ARG0]], %[[SV_ARG1]]
+// CHECK-SAME: : memref<?x?xf32, #[[MAP2]]>, memref<?x?xf32, #[[MAP2]]>)
+// CHECK-SAME: outs(%[[SV_TEMP]] : memref<?x?xf32, #[[MAP2]]>)
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[SV_TEMP]], %[[SV_ARG2]]
+// CHECK-SAME: : memref<?x?xf32, #[[MAP2]]>, memref<?xf32, #[[MAP3]]>)
+// CHECK-SAME: outs(%[[SV_ARG3]] : memref<?x?xf32, #[[MAP2]]>)
+// CHECK: scf.yield
+// CHECK: }
+
+// -----
+
+module {
+ func @sequence_of_matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
+ %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>,
+ %arg4: memref<?x?xf32>) {
+ %cst = constant 0.000000e+00 : f32
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %m = dim %arg0, %c0 : memref<?x?xf32>
+ %n1 = dim %arg1, %c1 : memref<?x?xf32>
+ %n2 = dim %arg2, %c1 : memref<?x?xf32>
+ %n3 = dim %arg3, %c1 : memref<?x?xf32>
+ %0 = alloc(%m, %n1) : memref<?x?xf32>
+ %1 = alloc(%m, %n2) : memref<?x?xf32>
+ linalg.fill(%0, %cst) : memref<?x?xf32>, f32
+ linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
+ outs(%0 : memref<?x?xf32>)
+ linalg.fill(%1, %cst) : memref<?x?xf32>, f32
+ linalg.matmul ins(%0, %arg2 : memref<?x?xf32>, memref<?x?xf32>)
+ outs(%1 : memref<?x?xf32>)
+ linalg.fill(%arg4, %cst) : memref<?x?xf32>, f32
+ linalg.matmul ins(%1, %arg3 : memref<?x?xf32>, memref<?x?xf32>)
+ outs(%arg4 : memref<?x?xf32>)
+ return
+ }
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// CHECK: func @sequence_of_matmul
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C16:.+]] = constant 16 : index
+// CHECK-DAG: %[[M:.+]] = dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[N1:.+]] = dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[N2:.+]] = dim %[[ARG2]], %[[C1]]
+// CHECK: %[[ALLOC1:.+]] = alloc(%[[M]], %[[N1]])
+// CHECK: %[[ALLOC2:.+]] = alloc(%[[M]], %[[N2]])
+// CHECK: scf.parallel (%[[IV0:.+]]) = (%[[C0]]) to (%[[M]])
+// CHECK-SAME: step (%[[C16]]) {
+// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
+// CHECK: %[[SV_ALLOC2:.+]] = subview %[[ALLOC2]][%[[IV0]], 0]
+// CHECK-SAME: [%[[TILE_M]], %[[N2]]]
+// CHECK: %[[M_2:.+]] = dim %[[ARG4]], %[[C0]]
+// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]]
+// CHECK: %[[N3:.+]] = dim %[[ARG4]], %[[C1]]
+// CHECK: %[[SV_ARG4:.+]] = subview %[[ARG4]][%[[IV0]], 0]
+// CHECK-SAME: [%[[TILE_M_2]], %[[N3]]]
+// CHECK: %[[SV_ARG4_2:.+]] = subview %[[ARG4]][%[[IV0]], 0]
+// CHECK-SAME: [%[[TILE_M]], %[[N3]]]
+// CHECK: %[[SV_ALLOC1:.+]] = subview %[[ALLOC1]][%[[IV0]], 0]
+// CHECK-SAME: [%[[TILE_M]], %[[N1]]]
+// CHECK: %[[SV_ARG2:.+]] = subview %[[ARG2]][0, 0] [%[[N1]], %[[N2]]]
+// CHECK: %[[N0:.+]] = dim %[[ARG0]], %[[C1]]
+// CHECK: %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[IV0]], 0]
+// CHECK-SAME: [%[[TILE_M:.+]], %[[N0]]]
+// CHECK: %[[SV_ARG1:.+]] = subview %[[ARG1]][0, 0] [%[[N0]], %[[N1]]]
+// CHECK: linalg.fill(%[[SV_ALLOC1]], %{{.+}})
+// CHECK: linalg.matmul ins(%[[SV_ARG0]], %[[SV_ARG1]]
+// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
+// CHECK-SAME: outs(%[[SV_ALLOC1]] : memref<?x?xf32, #[[MAP1]]>)
+// CHECK: linalg.fill(%[[SV_ALLOC2]], %{{.+}})
+// CHECK: linalg.matmul ins(%[[SV_ALLOC1]], %[[SV_ARG2]]
+// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
+// CHECK-SAME: outs(%[[SV_ALLOC2]] : memref<?x?xf32, #[[MAP1]]>)
+// CHECK: linalg.fill(%[[SV_ARG4_2]], %{{.+}})
+// CHECK: linalg.matmul ins(%[[SV_ALLOC2]], %[[ARG3]]
+// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
+// CHECK-SAME: outs(%[[SV_ARG4]] : memref<?x?xf32, #[[MAP1]]>)
+// CHECK: scf.yield
+// CHECK: }
+
diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index eb9e3a533138..5289b2d1055f 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -197,6 +197,44 @@ struct TestLinalgGreedyFusion
}
}
};
+
+/// Pass to test tile and fuse of a sequence of operations. Intended only for
+/// testing.
+struct TestLinalgTileAndFuseSequencePass
+ : public PassWrapper<TestLinalgTileAndFuseSequencePass, FunctionPass> {
+ TestLinalgTileAndFuseSequencePass() = default;
+ TestLinalgTileAndFuseSequencePass(
+ const TestLinalgTileAndFuseSequencePass &pass){};
+
+ ListOption<int64_t> tileSizes{
+ *this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
+ }
+
+ void runOnFunction() override {
+ FuncOp funcOp = getOperation();
+ auto &blocks = funcOp.getBody().getBlocks();
+ if (!llvm::hasSingleElement(blocks)) {
+ return;
+ }
+ SmallVector<LinalgOp, 2> linalgOps =
+ llvm::to_vector<2>(blocks.front().getOps<LinalgOp>());
+ Aliases aliases;
+ LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
+ OpBuilder builder(funcOp.getContext());
+ Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps(
+ builder, linalgOps, dependenceGraph,
+ LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(
+ LinalgTilingLoopType::ParallelLoops));
+ if (!tileAndFuseOps)
+ return signalPassFailure();
+ for (auto op : linalgOps)
+ op.erase();
+ }
+};
} // namespace
namespace mlir {
@@ -211,5 +249,12 @@ void registerTestLinalgGreedyFusion() {
"test-linalg-greedy-fusion",
"Test Linalg fusion by applying a greedy test transformation.");
}
+void registerTestLinalgTileAndFuseSequencePass() {
+ PassRegistration<TestLinalgTileAndFuseSequencePass>
+ testTileAndFuseSequencePass(
+ "test-linalg-tile-and-fuse",
+ "Test Linalg tiling and fusion of a sequence of Linalg operations.");
+}
+
} // namespace test
} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 4771b11b20e4..a0e36cf82534 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -74,6 +74,7 @@ void registerTestLinalgCodegenStrategy();
void registerTestLinalgFusionTransforms();
void registerTestLinalgGreedyFusion();
void registerTestLinalgHoisting();
+void registerTestLinalgTileAndFuseSequencePass();
void registerTestLinalgTransforms();
void registerTestLivenessPass();
void registerTestLoopFusion();
@@ -141,6 +142,7 @@ void registerTestPasses() {
test::registerTestLinalgFusionTransforms();
test::registerTestLinalgGreedyFusion();
test::registerTestLinalgHoisting();
+ test::registerTestLinalgTileAndFuseSequencePass();
test::registerTestLinalgTransforms();
test::registerTestLivenessPass();
test::registerTestLoopFusion();