[Mlir-commits] [mlir] 0caa82e - Revert "[mlir][Linalg] Fuse sequence of Linalg operation (on buffers)"
Mikhail Goncharov
llvmlistbot at llvm.org
Fri Nov 20 04:13:12 PST 2020
Author: Mikhail Goncharov
Date: 2020-11-20T13:12:54+01:00
New Revision: 0caa82e2ac53b2ff475531086dfe648fb2d6158a
URL: https://github.com/llvm/llvm-project/commit/0caa82e2ac53b2ff475531086dfe648fb2d6158a
DIFF: https://github.com/llvm/llvm-project/commit/0caa82e2ac53b2ff475531086dfe648fb2d6158a.diff
LOG: Revert "[mlir][Linalg] Fuse sequence of Linalg operation (on buffers)"
This reverts commit f8284d21a8e294d58a0acd4b8b2e906d7a9f110c.
Revert "[mlir][Linalg] NFC: Expose some utility functions used for promotion."
This reverts commit 0c59f51592ef5c014352994369f5216c6376fae1.
Revert "Remove unused isZero function"
This reverts commit 0f9f0a4046e11c2b4c130640f343e3b2b5db08c1.
Change f8284d21 led to multiple failures in IREE compilation.
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/fusion-pattern.mlir
mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
mlir/test/Dialect/Linalg/fusion-sequence.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 87ff2a97d93f..8d531a1e343a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -37,6 +37,14 @@ struct TiledLinalgOp {
SmallVector<Value, 4> tensorResults;
};
+struct TiledAndFusedLinalgOps {
+ LinalgOp op;
+ SmallVector<LinalgOp, 1> fusedProducers;
+ SmallVector<LinalgOp, 1> originalProducers;
+ SmallVector<Operation *, 4> fusedLoops;
+ SmallVector<Operation *, 4> unfusedLoops;
+};
+
/// Populates patterns for vectorization of all ConvN-D ops.
void populateConvVectorizationPatterns(
MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
@@ -65,11 +73,14 @@ void populateLinalgBufferizePatterns(MLIRContext *context,
Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
const LinalgTilingOptions &options);
-/// Fuse a sequence of linalg operations (`ops`) using tile-and-fuse. This
-/// proceeds as follows:
-/// - Find outer parallel loops in these ops that can be fused.
-/// - Tile fusable outer parallel loops of the last operation in the sequence.
-/// - Fuse the remaining operations with the tiled operation
+/// Tile and fuse the `op` with its producers. The tile and fuse proceeds in
+/// three steps
+/// - Find tile loops that are fusable with its producer tile loops (a.k.a. tile
+/// + fuse loops).
+/// - Tile just these loops of the consumer (root operation) and fuse with
+/// the producer.
+/// - Tile again the tiled consumer operation produced above to do rest of
+/// the tiling specified by the `tilingOptions`.
///
/// For example, consider the sequence of matmul below
///
@@ -96,39 +107,36 @@ Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
/// : memref<256x32xf32> to memref<16x32xf32, #map0>
/// %3 = subview %arg1[0, 0] [32, 32] [1, 1]
/// : memref<32x32xf32> to memref<32x32xf32, #map1>
-/// %4 = subview %arg3[0, 0] [32, 32] [1, 1]
-/// : memref<32x32xf32> to memref<32x32xf32, #map1>
/// linalg.matmul
/// ins(%2, %3 : memref<16x32xf32, #map0>, memref<32x32xf32, #map1>)
/// outs(%0 : memref<16x32xf32, #map0>)
-/// linalg.matmul
-/// ins(%0, %4 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>)
-/// outs(%1 : memref<16x8xf32, #map0>)
+/// scf.parallel (%arg6) = (%c0) to (%c32) step (%c8) {
+/// scf.for %arg7 = %c0 to %c32 step %c4 {
+/// %4 = subview %0[0, %arg7] [16, 4] [1, 1]
+/// : memref<16x32xf32, #map0> to memref<16x4xf32, #map0>
+/// %5 = subview %arg3[%arg7, %arg6] [4, 8] [1, 1]
+/// : memref<32x32xf32> to memref<4x8xf32, #map0>
+/// %6 = subview %1[0, %arg6] [16, 8] [1, 1]
+/// : memref<16x32xf32, #map0> to memref<16x8xf32, #map0>
+/// linalg.matmul
+/// ins(%4, %5 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>)
+/// outs(%6 : memref<16x8xf32, #map0>)
+/// }
+/// scf.yield
+/// }
+/// scf.yield
/// }
///
-/// `tilingOptions` are used to tile the corresponding operation in `ops` (the
-/// size of the former should be same as size of the latter. Based on how
-/// tile+fuse is implemented, the fused loops are generated based on the last
-/// operation in the sequence. For example, the tile sizes for the fused loops
-/// is obtained from `tilingOptions.back()`. The following tiling options are
-/// handled different in tile+fuse (compared to tile only)
+/// The following tiling options are handled differently in tile+fuse (compared
+/// to tile only)
/// - Interchange of the tiling loops is not supported right now.
-/// - Only the fused loops are distributed.
-struct TiledAndFusedLinalgOps {
- /// Operation obtained by tiling the last operation in sequence of `ops`
- /// passed to `tileAndFuseLinalgOps`.
- LinalgOp op;
- /// The dimension of the loops that are fused.
- std::set<unsigned> fusedLoopDims;
- /// The generated fused operations (created within the fused loops).
- SmallVector<LinalgOp, 1> fusedProducers;
- /// The fused loop generated.
- SmallVector<Operation *, 4> fusedLoops;
-};
+/// - Distribution is only done for the tile+fuse loops. The tiled loops
+/// generated by the second tiling is not distributed.
Optional<TiledAndFusedLinalgOps>
-tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
+tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
const LinalgDependenceGraph &dependenceGraph,
- const LinalgTilingOptions &tilingOptions);
+ const LinalgTilingOptions &tilingOptions,
+ const LinalgFusionOptions &fusionOptions);
/// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`.
/// This is an in-place transformation controlled by `interchangeVector`.
@@ -234,20 +242,6 @@ struct LinalgPromotionOptions {
}
};
-/// Creates a new buffer using the `allocationFn` provided. The size of this
-/// buffer is the smallest constant bounding size along each dimension that can
-/// be computed for the size of the result of `subView`. Returns the allocated
-/// buffer as `fullLocalView` and the view that matches the size of the result
-/// of subview operation as `partialLocalView`.
-struct PromotionInfo {
- Value fullLocalView;
- Value partialLocalView;
-};
-Optional<PromotionInfo>
-promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, SubViewOp subView,
- AllocBufferCallbackFn allocationFn,
- OperationFolder *folder = nullptr);
-
/// Promotes the `subViews` into a new buffer allocated at the insertion point
/// `b`. Promotion occurs in 3 steps:
/// 1. Create a new buffer for a full tile (i.e. not clipped at the boundary).
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 1eaf8b0e709c..f5669e383368 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -17,7 +17,6 @@
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SetVector.h"
using mlir::edsc::intrinsics::AffineIndexedValue;
@@ -83,13 +82,6 @@ bool isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer,
Value consumedView, LinalgOp producer);
-using FusableOpDependencesTy = llvm::MapVector<
- Operation *,
- SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>;
-FusableOpDependencesTy
-findAllFusableDependences(ArrayRef<LinalgOp> ops,
- const LinalgDependenceGraph &dependenceGraph);
-
/// Fuses producer into consumer if the producer is structurally feasible and
/// the fusion would not violate dependencies.
/// Implements the fusion part of the "tileAndFuse on buffers"
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 45a68fcba4a2..969bea4a4549 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -178,9 +178,6 @@ static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
Value shape = en.value();
SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
for (auto en2 : llvm::enumerate(map.getResults())) {
- auto dimExpr = en2.value().dyn_cast<AffineDimExpr>();
- if (!dimExpr)
- continue;
if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
<< loopDepth << "\n");
@@ -193,18 +190,49 @@ static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
llvm_unreachable("Expect to be able to extract a shape defining loop range");
}
-/// Fuse the producer by cloning the `producer`. The `fusedLoopsAndRanges`
-/// provides the loop range information for the fused loops. The rest are
-/// obtained from the producer itself, since they are not tiled + fused.
-static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
- const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
+/// Fuses the producer of `producerIdx` into the loop immediately enclosing
+/// `consumer`. This is achieved by "recomputing" the `producer` at the time it
+/// is needed just before the `consumer.
+///
+/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
+/// 2 cases:
+/// 1. Buffer case: `producerIdx` is the index of the buffer in
+/// `producer.getOutputBuffers()`.
+/// 2. Tensor case: `producerIdx` is the index of the tensor in
+/// `producer.getResults()`.
+static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
+ LinalgOp consumer, unsigned consumerIdx) {
+ Operation *shapeProducingOp =
+ consumer.getShapedOperand(consumerIdx).getDefiningOp();
+ assert((isa<SubViewOp>(shapeProducingOp) ||
+ isa<SubTensorOp>(shapeProducingOp)) &&
+ "SubviewOp or SubTensorOp expected");
+
+ // loopToOperandRangesMaps are permutations-only by construction:
+ // we can always identify a data dimension with a (at least one) loop
+ // dimension.
+ // TODO: extend this with range inference.
+ AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
+ LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
+ << ", producer map: " << producerMap << "\n");
unsigned nPar = producer.getNumParallelLoops();
unsigned nRed = producer.getNumReductionLoops();
unsigned nWin = producer.getNumWindowLoops();
SmallVector<Range, 8> loopRanges(nPar + nRed + nWin);
- for (auto fusedLoops : fusedLoopsAndRanges)
- loopRanges[fusedLoops.first] = fusedLoops.second;
+
+ // Iterate over dimensions identified by the producer map for `producerIdx`.
+ // This defines a subset of the loop ranges that we need to complete later.
+ auto loc = consumer.getLoc();
+ for (auto en : llvm::enumerate(producerMap.getResults())) {
+ unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
+ loopRanges[posInProducerLoop] =
+ isa<SubViewOp>(shapeProducingOp)
+ ? cast<SubViewOp>(shapeProducingOp)
+ .getOrCreateRanges(b, loc)[en.index()]
+ : cast<SubTensorOp>(shapeProducingOp)
+ .getOrCreateRanges(b, loc)[en.index()];
+ }
// Iterate over all dimensions. For the dimensions not identified by the
// producer map for `producerIdx`, we need to explicitly compute the shape
@@ -222,45 +250,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
}
}
- return cloneWithLoopRanges(b, producer.getLoc(), producer, loopRanges);
-}
-
-/// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is
-/// expected to be defined by a subview op or a subtensor op.
-static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
- Value shapedOperand, unsigned dim) {
- Operation *shapeProducingOp = shapedOperand.getDefiningOp();
- if (auto subViewOp = dyn_cast<SubViewOp>(shapeProducingOp))
- return subViewOp.getOrCreateRanges(b, loc)[dim];
- if (auto subTensorOp = dyn_cast<SubTensorOp>(shapeProducingOp))
- return subTensorOp.getOrCreateRanges(b, loc)[dim];
- llvm_unreachable("SubviewOp or SubTensorOp expected");
-}
-
-/// Fuses the producer of `producerIdx` into the loop immediately enclosing
-/// `consumer`. This is achieved by "recomputing" the `producer` at the time it
-/// is needed just before the `consumer.
-///
-/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
-/// 2 cases:
-/// 1. Buffer case: `producerIdx` is the index of the buffer in
-/// `producer.getOutputBuffers()`.
-/// 2. Tensor case: `producerIdx` is the index of the tensor in
-/// `producer.getResults()`.
-static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
- LinalgOp consumer, unsigned consumerIdx) {
- AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
- LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
- << ", producer map: " << producerMap << "\n");
- DenseMap<unsigned, Range> fusedLoopsAndRanges;
- Location loc = consumer.getLoc();
- Value shapedOperand = consumer.getShapedOperand(consumerIdx);
- for (auto en : llvm::enumerate(producerMap.getResults())) {
- unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
- fusedLoopsAndRanges[posInProducerLoop] =
- getRangeFromOperandShape(b, loc, shapedOperand, en.index());
- }
- return fuse(b, producer, fusedLoopsAndRanges);
+ return cloneWithLoopRanges(b, loc, producer, loopRanges);
}
// Encode structural fusion safety preconditions.
@@ -531,68 +521,9 @@ static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
return getProjectedMap(map, projectedDims);
}
-/// Returns the mapping from iterations in the consumer that write to the same
-/// location as the iterations in the producer. To do so use
-/// - indexing map of the fused view in the consumer : consumerIndexMap
-/// - indexing map of the fused view in the producer : producerIndexMap
-/// consumerLoopToProducerLoop =
-/// inverse(producerIndexMap).compose(consumerIndexMap)
-static Optional<AffineMap> getConsumerLoopToProducerLoopMap(
- LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
- auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
- AffineMap producerIndexingMap =
- producer.getIndexingMap(dependence.dependentOpView.operandIndex);
- auto consumer = cast<LinalgOp>(dependence.indexingOpView.op);
- AffineMap consumerIndexingMap =
- consumer.getIndexingMap(dependence.indexingOpView.operandIndex);
-
- AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
- producer.iterator_types().getValue(), producerIndexingMap);
- if (!prunedProducerIndexingMap.isPermutation())
- return None;
-
- if (consumerIndexingMap.getNumResults() !=
- prunedProducerIndexingMap.getNumResults())
- return None;
-
- LLVM_DEBUG({
- llvm::dbgs() << "\t producerMap : ";
- producerIndexingMap.print(llvm::dbgs());
- llvm::dbgs() << " pruned : ";
- prunedProducerIndexingMap.print(llvm::dbgs());
- llvm::dbgs() << "\n";
- llvm::dbgs() << "\t consumerMap : ";
- consumerIndexingMap.print(llvm::dbgs());
- llvm::dbgs() << "\n";
- });
-
- AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap);
- if (!invProducerIndexMap)
- return None;
-
- return invProducerIndexMap.compose(consumerIndexingMap);
-}
-
-/// Given a projected permutation `map`, returns true if the map changes the
-/// order in which the fused loop dimension appear.
-static bool doesTransposeAccess(AffineMap map,
- const std::set<unsigned> &fusableLoops) {
- Optional<unsigned> lastFusableLoop;
- for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) {
- return expr.cast<AffineDimExpr>().getPosition();
- })) {
- if (!fusableLoops.count(pos))
- continue;
- if (!lastFusableLoop) {
- lastFusableLoop = pos;
- continue;
- }
- if (pos <= lastFusableLoop.getValue())
- return true;
- lastFusableLoop = pos;
- }
- return false;
-}
+using FusableOpDependencesTy = llvm::MapVector<
+ Operation *,
+ SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>;
/// Returns the positions of the loop in `op` that can be tiled based on the
/// operations that are to be fused with it. For example, in a
@@ -607,7 +538,13 @@ static bool doesTransposeAccess(AffineMap map,
/// 2. Of the parallel loops only some can be fused. Only those loops can be
/// fused such where the fusable loops iteration space only touches one tile
/// of the fused operation. This is because the producer (which is writing
-/// the fused subview) has update semantics.
+/// the fused subview) has update semantics. To compute this,
+/// a. Find the mapping from iterations in the consumer that write to the
+/// same location as the iterations in the producer. To do so use
+/// - indexing map of the fused view in the consumer : consumerIndexMap
+/// - indexing map of the fused view in the producer : producerIndexMap
+/// consumerLoopToProducerLoop =
+/// inverse(producerIndexMap).compose(consumerIndexMap)
///
/// Since an inverse computation is needed, we need to consider the projection
/// of the producerIndexMap w.r.t the parallel loops. The actual fusable loops
@@ -645,9 +582,8 @@ static bool doesTransposeAccess(AffineMap map,
/// submap with only parallel loops = affine_map<(i, j) -> (j)>
/// Fused dimensions : j
static std::set<unsigned>
-collectFusableLoops(ArrayRef<LinalgOp> ops,
- const FusableOpDependencesTy &fusableDependences) {
- assert(!ops.empty());
+collectTileAndFuseLoops(LinalgOp op,
+ const FusableOpDependencesTy &fusableDependences) {
auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
return linalgOp.iterator_types()
.getValue()
@@ -658,245 +594,289 @@ collectFusableLoops(ArrayRef<LinalgOp> ops,
.size();
};
- size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back());
- for (auto op : ops.drop_back()) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "Op : ";
+ op.getOperation()->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n";
+ });
+
+ size_t numOuterParallelLoops = getNumOuterParallelLoops(op);
+ for (auto dependence : fusableDependences) {
+ linalg::LinalgOp producer = cast<linalg::LinalgOp>(dependence.first);
numOuterParallelLoops =
- std::min(numOuterParallelLoops, getNumOuterParallelLoops(op));
+ std::min(numOuterParallelLoops, getNumOuterParallelLoops(producer));
}
std::set<unsigned> fusableLoops;
auto range = llvm::seq<unsigned>(0, numOuterParallelLoops);
fusableLoops.insert(range.begin(), range.end());
-
- for (auto op : reverse(ops)) {
- for (auto dependence : fusableDependences.lookup(op)) {
- LLVM_DEBUG({
- llvm::dbgs() << "\t fusable :";
- for (unsigned i : fusableLoops)
- llvm::dbgs() << " " << i;
- llvm::dbgs() << "\n";
- });
-
- Optional<AffineMap> consumerLoopToProducerLoop =
- getConsumerLoopToProducerLoopMap(dependence);
- if (!consumerLoopToProducerLoop) {
- op.emitRemark("failed to get map from consumer loop to producer loop");
- return {};
- }
- // todo: This condition is only an implementation limitation. When fusing
- // the operation, if the accesses in the producer/consumer are transposes
- // of each other, the loop bounds for the tiled producer can be
- // manipulated accordingly. This requires some additional bookkeeping in
- // the implementation of tile+fuse that is defered to later.
- if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) {
- op.emitRemark("unhandled fusion when fusion requires permutation");
- return {};
- }
-
- std::set<unsigned> candidates;
- for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) {
- unsigned position = expr.cast<AffineDimExpr>().getPosition();
- if (fusableLoops.count(position))
- candidates.insert(position);
- }
- LLVM_DEBUG({
- llvm::dbgs() << "\t candidates :";
- for (unsigned i : candidates)
- llvm::dbgs() << " " << i;
- llvm::dbgs() << "\n";
- });
- if (candidates.empty())
- return {};
- std::swap(candidates, fusableLoops);
+ for (auto dependence : fusableDependences) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "\t fusable :";
+ for (unsigned i : fusableLoops)
+ llvm::dbgs() << " " << i;
+ llvm::dbgs() << "\n";
+ });
+ linalg::LinalgOp producer = cast<linalg::LinalgOp>(dependence.first);
+
+ assert(!dependence.second.empty() &&
+ "unexpected producer but not dependences");
+ AffineMap producerIndexingMap = producer.getIndexingMap(
+ dependence.second.front().dependentOpView.operandIndex);
+ AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
+ producer.iterator_types().getValue(), producerIndexingMap);
+ if (!prunedProducerIndexingMap.isPermutation())
+ return {};
+
+ AffineMap consumerIndexingMap = op.getIndexingMap(
+ dependence.second.front().indexingOpView.operandIndex);
+ if (consumerIndexingMap.getNumResults() !=
+ prunedProducerIndexingMap.getNumResults())
+ return {};
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "\t producerMap : ";
+ producerIndexingMap.print(llvm::dbgs());
+ llvm::dbgs() << " pruned : ";
+ prunedProducerIndexingMap.print(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ llvm::dbgs() << "\t consumerMap : ";
+ consumerIndexingMap.print(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+
+ AffineMap invProducerIndexMap =
+ inversePermutation(prunedProducerIndexingMap);
+ if (!invProducerIndexMap)
+ return {};
+
+ AffineMap consumerLoopToProducerLoop =
+ invProducerIndexMap.compose(consumerIndexingMap);
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "\t consumerLoopToProducerLoop : ";
+ consumerLoopToProducerLoop.print(llvm::dbgs());
+ });
+
+ std::set<unsigned> candidates;
+ for (AffineExpr expr : consumerLoopToProducerLoop.getResults()) {
+ AffineDimExpr dimExpr = expr.dyn_cast<AffineDimExpr>();
+ if (!dimExpr)
+ continue;
+ unsigned position = dimExpr.getPosition();
+ if (fusableLoops.count(position))
+ candidates.insert(position);
}
+ LLVM_DEBUG({
+ llvm::dbgs() << "\t candidates :";
+ for (unsigned i : candidates)
+ llvm::dbgs() << " " << i;
+ llvm::dbgs() << "\n";
+ });
+ if (candidates.empty())
+ return {};
+ std::swap(candidates, fusableLoops);
}
return fusableLoops;
}
-/// Find all dependences that are fusable.
-FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
- ArrayRef<LinalgOp> ops, const LinalgDependenceGraph &dependenceGraph) {
+/// Find all dependences that are to be fusable.
+static FusableOpDependencesTy
+findAllFusableDependences(LinalgOp op,
+ const LinalgDependenceGraph &dependenceGraph,
+ const LinalgFusionOptions &fusionOptions) {
FusableOpDependencesTy fusableDependences;
// TODO: Currently fusion would not be legal if the fusable dependence is to
// the same producer but different indexing map in the consumer. Fix this, but
// in the meanwhile disallow such a fusion.
DenseMap<Operation *, AffineMap> fusedProducerIndexingMap;
- for (LinalgOp op : reverse(ops)) {
- for (auto operandIndex :
- llvm::seq<unsigned>(0, op.getNumInputsAndOutputBuffers())) {
- Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
- fusableDependence =
- findFusableProducer(op, operandIndex, dependenceGraph);
- if (!fusableDependence)
- continue;
- LinalgOp producerOp =
- cast<LinalgOp>(fusableDependence->dependentOpView.op);
- // Do not fuse dependences that are to operations not in the same basic
- // block. This avoid moving fused operations across loops that might
- // themselves carry dependency making the fusion illegal.
- if (producerOp.getOperation()->getBlock() !=
- op.getOperation()->getBlock()) {
- op.emitRemark("unhandled fusion of ops in different basic blocks");
- return FusableOpDependencesTy{};
- }
- // Make sure that the indexing map of the view used for fusion in the
- // producer is a projected permutation.
- unsigned producerIdx = fusableDependence->dependentOpView.operandIndex;
- AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
- if (!producerMap.isProjectedPermutation()) {
- op.emitRemark(
- "unhandled non permutation indexing map for fused view in "
- "producer for operand at index ")
- << operandIndex;
- return FusableOpDependencesTy{};
- }
-
- unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex;
- AffineMap consumerMap = op.getIndexingMap(consumerIdx);
- if (!consumerMap.isProjectedPermutation()) {
- op.emitRemark(
- "unhandled case where indexing map for fused view in the consumer "
- "is "
- "not a projected permuration while fusing at index ")
- << operandIndex;
- return FusableOpDependencesTy{};
- }
+ for (auto operandIndex : fusionOptions.indicesToFuse) {
+ auto fusableDependence =
+ findFusableProducer(op, operandIndex, dependenceGraph);
+ if (!fusableDependence)
+ return FusableOpDependencesTy{};
+ LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
+ // Do not fuse dependences that are to operations not in the same basic
+ // block. This avoid moving fused operations across loops that might
+ // themselves carry dependency making the fusion illegal.
+ if (producerOp.getOperation()->getBlock() !=
+ op.getOperation()->getBlock()) {
+ op.emitRemark("unhandled fusion of ops in different basic blocks");
+ return FusableOpDependencesTy{};
+ }
+ // Make sure that the indexing map of the view used for fusion in the
+ // producer is a projected permutation.
+ unsigned producerIdx = fusableDependence->dependentOpView.operandIndex;
+ AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
+ if (!producerMap.isProjectedPermutation()) {
+ op.emitRemark("unhandled non permutation indexing map for fused view in "
+ "producer for operand at index ")
+ << operandIndex;
+ return FusableOpDependencesTy{};
+ }
- // Check if the producer is already a fusion candidate. Cannot fuse this
- // dependence if it has a different indexing map when used in the
- // consumer.
- if (fusedProducerIndexingMap.count(producerOp.getOperation()) &&
- fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) {
- op.emitRemark(
- "unhandled fusion to the same producer but with different "
- "indexing maps");
- return FusableOpDependencesTy{};
- }
- fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap;
+ unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex;
+ AffineMap consumerMap = op.getIndexingMap(consumerIdx);
+ if (!consumerMap.isProjectedPermutation()) {
+ op.emitRemark(
+ "unhandled case where indexing map for fused view in the consumer is "
+ "not a projected permutation while fusing at index ")
+ << operandIndex;
+ return FusableOpDependencesTy{};
+ }
- fusableDependences[producerOp.getOperation()].push_back(
- *fusableDependence);
+ // Check if the producer is already a fusion candidate. Cannot fuse this
+ // dependence if it has a different indexing map when used in the consumer.
+ if (fusedProducerIndexingMap.count(producerOp.getOperation()) &&
+ fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) {
+ op.emitRemark("unhandled fusion to the same producer but with different "
+ "indexing maps");
+ return FusableOpDependencesTy{};
}
+ fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap;
+
+ fusableDependences[producerOp.getOperation()].push_back(*fusableDependence);
}
return fusableDependences;
}
-/// Tile the fused loops in the root operation, by setting the tile sizes for
-/// all other loops to zero (those will be tiled later).
-static Optional<TiledLinalgOp> tileRootOperation(
- OpBuilder &builder, LinalgOp op, ArrayRef<Value> tileSizeVector,
- const LinalgTilingOptions &options, const std::set<unsigned> &fusedLoops) {
- SmallVector<Value, 4> tileSizes(tileSizeVector.begin(), tileSizeVector.end());
- auto zero = std_constant_index(0);
- for (unsigned i = 0, e = tileSizes.size(); i != e; ++i)
- if (!fusedLoops.count(i))
- tileSizes[i] = zero;
- LinalgTilingOptions tileFusedLoopsOptions = options;
- tileFusedLoopsOptions.setTileSizes(tileSizes);
- return tileLinalgOp(builder, op, tileFusedLoopsOptions);
-}
-
-/// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected
-/// to be a tiled operation such that it is valid to fuse all operations in
-/// `fusionCandidates`, i.e. move the operation within the inter-tile loops of
-/// `tiledOp`.
-static SmallVector<LinalgOp, 1>
-fuseOperations(OpBuilder &builder, LinalgOp tiledOp,
- ArrayRef<LinalgOp> fusionCandidates,
- const FusableOpDependencesTy &fusableDependences,
- const std::set<unsigned> &fusedLoops) {
- OpBuilder::InsertionGuard guard(builder);
- builder.setInsertionPoint(tiledOp);
- DenseMap<unsigned, Range> fusedLoopsAndRanges;
- for (unsigned loop : fusedLoops) {
- ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop);
- fusedLoopsAndRanges[loop] = getRangeFromOperandShape(
- builder, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension);
- }
-
- SmallVector<LinalgOp, 1> fusedOps(fusionCandidates.size());
- for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) {
- LinalgOp fusedOp = fuse(builder, candidate.value(), fusedLoopsAndRanges);
- fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
- builder.setInsertionPoint(fusedOp);
- }
- return fusedOps;
+static bool isZero(Value v) {
+ if (auto cst = v.getDefiningOp<ConstantIndexOp>())
+ return cst.getValue() == 0;
+ return false;
}
template <typename LoopType>
static Optional<TiledAndFusedLinalgOps>
-tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
+tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op,
const LinalgDependenceGraph &dependenceGraph,
- const LinalgTilingOptions &tilingOptions) {
- if (ops.empty())
- return llvm::None;
- LinalgOp rootOp = ops.back();
- for (auto op : enumerate(ops)) {
- // TODO: Nothing in the fusion of sequence of ops is specific to
- // buffers. This check can be removed after it is tested on tensors.
- LinalgOp linalgOp = op.value();
- if (!linalgOp.hasBufferSemantics()) {
- linalgOp.emitError("tile and fuse only tested for buffer operation");
- return llvm::None;
- }
- }
- // TODO: Support interchange with tile + fuse. This might actually help do
- // better fusion.
+ const LinalgTilingOptions &tilingOptions,
+ const LinalgFusionOptions &fusionOptions) {
+ assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
+ // Some of the tiling options might not be supportable with tile and fuse.
+ // TODO: Support interchange with tile + fuse.
if (!tilingOptions.interchangeVector.empty()) {
- rootOp.emitError("unable to handle tile and fuse with interchange");
+ op.emitError("unable to handle tile and fuse with interchange");
return llvm::None;
}
- OpBuilder::InsertionGuard guard(builder);
- builder.setInsertionPoint(rootOp);
- ScopedContext scope(builder, rootOp.getLoc());
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(op);
+ ScopedContext scope(rewriter, op.getLoc());
// Find all the producers.
FusableOpDependencesTy fusableDependences =
- findAllFusableDependences(ops, dependenceGraph);
+ findAllFusableDependences(op, dependenceGraph, fusionOptions);
if (fusableDependences.empty())
return llvm::None;
+ // Enforce the convention that "tiling by zero" skips tiling a particular
+ // dimension. This convention is significantly simpler to handle instead of
+ // adjusting affine maps to account for missing dimensions.
+ auto nLoops = op.getNumLoops();
+ SmallVector<Value, 4> tileSizeVector =
+ tilingOptions.tileSizeComputationFunction(rewriter, op);
+ if (tileSizeVector.size() < nLoops) {
+ auto zero = std_constant_index(0);
+ tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
+ }
+
TiledAndFusedLinalgOps ret;
+
// Find the loops that can be tiled and fused.
- ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences);
+ std::set<unsigned> tileFuseLoops =
+ collectTileAndFuseLoops(op, fusableDependences);
// If there are no fusable dependences or there are no tile+fusable loops,
// just return.
- if (ret.fusedLoopDims.empty()) {
+ if (tileFuseLoops.empty()) {
return llvm::None;
}
- // Tile the fused loops in the last operation in the list.
- SmallVector<Value, 4> tileSizeVector =
- tilingOptions.tileSizeComputationFunction(builder, rootOp);
- Optional<TiledLinalgOp> tiledRootOp = tileRootOperation(
- builder, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
- if (!tiledRootOp) {
- rootOp.emitError("failed to tile the fused loops");
+ // Get the tile sizes for the first and second tiling steps. For the first
+ // step the tile size are set to zero for the loops that arent
+ // fused. Similarly for the second step, the tile sizes are set to zero for
+ // the loops that are fused. For example, if for the following input
+ //
+ // ```
+ // linalg.add ins(%a, %b) outs(%c)
+ // linalg.matmul ins(%d, %c) outs(%e)
+ // ```
+ //
+ // if the tile sizes of the `{i, j, k}` loops where given as `{ti, tj, tk}`
+ // respectively, and since only `j` can be tiled and fused. The tile sizes
+ // would be `{0, t_j, 0}` for the first tiling that tiles just the fusable
+ // loops. The second tiling would be use tile sizes of `{t_i, 0, t_k}` to tile
+ // the tiled matmul generated by the first tiling step.
+ SmallVector<Value, 4> tileAndFuseSizes, tileSizes;
+ for (auto tileSize : enumerate(tileSizeVector)) {
+ auto zero = std_constant_index(0);
+ if (tileFuseLoops.count(tileSize.index())) {
+ tileAndFuseSizes.push_back(tileSize.value());
+ tileSizes.push_back(zero);
+ } else {
+ tileSizes.push_back(tileSize.value());
+ tileAndFuseSizes.push_back(zero);
+ }
+ }
+
+ // Tile for the loops that can be fused.
+ LinalgTilingOptions firstTilingOptions = tilingOptions;
+ firstTilingOptions.setTileSizes(tileAndFuseSizes);
+ Optional<TiledLinalgOp> firstTiledOp =
+ tileLinalgOp(rewriter, op, firstTilingOptions);
+ if (!firstTiledOp)
return llvm::None;
+ ret.op = firstTiledOp->op;
+ ret.fusedLoops.assign(firstTiledOp->loops.begin(), firstTiledOp->loops.end());
+
+ rewriter.setInsertionPoint(ret.op);
+ // Fuse the operands.
+ for (auto dependence : fusableDependences) {
+ LinalgOp producerOp = cast<LinalgOp>(dependence.first);
+ unsigned producerIdx =
+ dependence.second.front().dependentOpView.operandIndex;
+ unsigned consumerIdx =
+ dependence.second.front().indexingOpView.operandIndex;
+ LinalgOp fusedOp = fuse(rewriter, producerOp,
+ producerOp.getOutputIndex(producerIdx).getValue(),
+ ret.op, consumerIdx);
+ ret.fusedProducers.push_back(fusedOp);
+ ret.originalProducers.push_back(producerOp);
+ }
+
+ if (!llvm::all_of(tileSizes, isZero)) {
+ // Tile the remaining loops of the root operation.
+ LinalgTilingOptions secondTilingOptions = tilingOptions;
+ // The distribution is done only for the tile+fused loops.
+ secondTilingOptions.distribution = llvm::None;
+ secondTilingOptions.setTileSizes(tileSizes);
+ Optional<TiledLinalgOp> secondTiledOp =
+ tileLinalgOp(rewriter, ret.op, secondTilingOptions);
+ if (!secondTiledOp)
+ return llvm::None;
+ ret.unfusedLoops.assign(secondTiledOp->loops.begin(),
+ secondTiledOp->loops.end());
+ rewriter.eraseOp(ret.op);
+ ret.op = secondTiledOp->op;
}
- ret.op = tiledRootOp->op;
- ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
- // Fuse the other operations into the fused inter-tile loops produced above.
- ret.fusedProducers = fuseOperations(builder, ret.op, ops.drop_back(),
- fusableDependences, ret.fusedLoopDims);
return ret;
}
Optional<TiledAndFusedLinalgOps>
-mlir::linalg::tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
+mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
const LinalgDependenceGraph &dependenceGraph,
- const LinalgTilingOptions &tilingOptions) {
+ const LinalgTilingOptions &tilingOptions,
+ const LinalgFusionOptions &fusionOptions) {
switch (tilingOptions.loopType) {
case LinalgTilingLoopType::Loops:
- return tileAndFuseLinalgOpsImpl<scf::ForOp>(builder, ops, dependenceGraph,
- tilingOptions);
+ return tileAndFuseLinalgOpsImpl<scf::ForOp>(rewriter, op, dependenceGraph,
+ tilingOptions, fusionOptions);
case LinalgTilingLoopType::ParallelLoops:
return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(
- builder, ops, dependenceGraph, tilingOptions);
+ rewriter, op, dependenceGraph, tilingOptions, fusionOptions);
default:;
}
return llvm::None;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index a824f6eb620f..e002336ed1c6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -166,6 +166,11 @@ struct LinalgOpInstancePromotionOptions {
/// Alignment of promoted buffer.
Optional<unsigned> alignment;
};
+
+struct PromotionInfo {
+ Value fullLocalView;
+ Value partialLocalView;
+};
} // namespace
LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
@@ -228,10 +233,10 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
// To account for general boundary effects, padding must be performed on the
// boundary tiles. For now this is done with an unconditional `fill` op followed
// by a partial `copy` op.
-Optional<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
- OpBuilder &b, Location loc, SubViewOp subView,
- AllocBufferCallbackFn allocationFn, OperationFolder *folder) {
- ScopedContext scopedContext(b, loc);
+static Optional<PromotionInfo>
+promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, SubViewOp subView,
+ LinalgOpInstancePromotionOptions const &options,
+ OperationFolder *folder) {
auto viewType = subView.getType();
auto rank = viewType.getRank();
SmallVector<Value, 4> fullSizes, partialSizes;
@@ -249,7 +254,8 @@ Optional<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
SmallVector<int64_t, 4> dynSizes(fullSizes.size(), -1);
// If a callback is not specified, then use the default implementation for
// allocating the promoted buffer.
- Optional<Value> fullLocalView = allocationFn(b, subView, fullSizes, folder);
+ Optional<Value> fullLocalView =
+ options.allocationFn(b, subView, fullSizes, folder);
if (!fullLocalView)
return {};
auto zero = folded_std_constant_index(folder, 0);
@@ -273,8 +279,8 @@ promoteSubViews(OpBuilder &b, Location loc,
for (auto v : options.subViews) {
SubViewOp subView = cast<SubViewOp>(v.second.getDefiningOp());
- Optional<PromotionInfo> promotionInfo = promoteSubviewAsNewBuffer(
- b, loc, subView, options.allocationFn, folder);
+ Optional<PromotionInfo> promotionInfo =
+ promoteSubviewAsNewBuffer(b, loc, subView, options, folder);
if (!promotionInfo)
return {};
promotionInfoMap[v.first] = *promotionInfo;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index a855c07cb8d4..836cc28e0a47 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -165,69 +165,17 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
if (!linalgOp.hasBufferSemantics())
return failure();
- DenseSet<Operation *> producers;
- producers.insert(linalgOp);
- for (auto dependence : dependenceGraph.getDependentOperations(linalgOp)) {
- if (!fusionOptions.indicesToFuse.count(
- dependence.indexingOpView.operandIndex))
- continue;
- if (isa<LinalgOp>(dependence.dependentOpView.op))
- producers.insert(dependence.dependentOpView.op);
- }
-
- SmallVector<LinalgOp, 1> fusionOps;
- for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
- ++it) {
- auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
- if (producerLinalgOp && producers.count(producerLinalgOp))
- fusionOps.push_back(producerLinalgOp);
- }
- fusionOps.push_back(linalgOp);
-
- SmallVector<Value, 4> tileSizes =
- tilingOptions.tileSizeComputationFunction(rewriter, op);
- LinalgTilingOptions instanceTilingOptions = tilingOptions;
- instanceTilingOptions.setTileSizes(tileSizes);
Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
- rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
+ rewriter, op, dependenceGraph, tilingOptions, fusionOptions);
if (!tiledAndFusedOps)
return failure();
-
- // Tile the unfused loops;
- SmallVector<Value, 4> unfusedLoopTileSizes;
- Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0);
- for (auto tileSize : enumerate(tileSizes)) {
- if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
- unfusedLoopTileSizes.push_back(zero);
- else
- unfusedLoopTileSizes.push_back(tileSize.value());
- }
- // Tile the loop only if there is a non-zero tile size.
- if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
- unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
- if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
- if (auto cst = val.getDefiningOp<ConstantIndexOp>())
- return cst.getValue() != 0;
- return true;
- })) {
- LinalgTilingOptions unfusedTilingOptions = tilingOptions;
- unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
- Optional<TiledLinalgOp> unfusedTiledOp =
- tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
- if (!unfusedTiledOp)
- return failure();
- rewriter.eraseOp(tiledAndFusedOps->op);
- tiledAndFusedOps->op = unfusedTiledOp->op;
- }
-
marker.replaceLinalgMarker(rewriter, tiledAndFusedOps->op.getOperation());
for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
fusedOpMarker.replaceLinalgMarker(rewriter, fusedOp.getOperation());
}
- for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
+ for (auto origProducerOp : tiledAndFusedOps->originalProducers)
originalOpMarker.replaceLinalgMarker(rewriter,
origProducerOp.getOperation());
- }
rewriter.updateRootInPlace(
op, [&]() { originalOpMarker.replaceLinalgMarker(rewriter, op); });
return success();
diff --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir
index fa471811ef4e..2ddc66651db2 100644
--- a/mlir/test/Dialect/Linalg/fusion-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-pattern.mlir
@@ -47,9 +47,7 @@ module {
// CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N_2]]]
// CHECK: %[[SV3:.+]] = subview %[[ARG2]][%[[IV0]], %[[IV1]]]
// CHECK-SAME: [%[[TILE_M_2]], %[[TILE_N_2]]]
-// CHECK: %[[SV3_2:.+]] = subview %[[ARG2]][%[[IV0]], %[[IV1]]]
-// CHECK-SAME: [%[[TILE_M]], %[[TILE_N]]]
-// CHECK: linalg.fill(%[[SV3_2]], %[[CST]])
+// CHECK: linalg.fill(%[[SV3]], %[[CST]])
// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_producer"
// CHECK: scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] {
// CHECK: %[[TILE_K:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K]]]
@@ -111,12 +109,9 @@ module {
// CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[N_2]]]
// CHECK: %[[SV2:.+]] = subview %[[ARG3]][0, %[[IV0]]]
// CHECK-SAME: [%[[M]], %[[TILE_N_2]]]
-// CHECK: %[[K_2:.+]] = dim %[[ARG1]], %[[C0]]
// CHECK: %[[SV3:.+]] = subview %[[ARG1]][0, %[[IV0]]]
-// CHECK-SAME: [%[[K_2]], %[[TILE_N]]]
-// CHECK: %[[SV3_2:.+]] = subview %[[ARG2]][0, %[[IV0]]]
-// CHECK-SAME: [%[[K_2]], %[[TILE_N]]]
-// CHECK: linalg.copy(%[[SV3]], %[[SV3_2]])
+// CHECK-SAME: [%[[K]], %[[TILE_N]]]
+// CHECK: linalg.copy(%[[SV3]], %[[SV1]])
// CHECK-SAME: __internal_linalg_transform__ = "after_rhs_fusion_producer"
// CHECK-NOT: linalg.fill
// CHECK-DAG: %[[M_2:.+]] = dim %[[ARG0]], %[[C0]]
@@ -191,16 +186,11 @@ module {
// CHECK: %[[N:.+]] = dim %[[ARG3]], %[[C1]]
// CHECK: %[[SV2:.+]] = subview %[[ARG3]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M_2]], %[[N]]]
-// CHECK: %[[SV2_2:.+]] = subview %[[ARG3]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[N]]]
-// CHECK: %[[K_2:.+]] = dim %[[ARG0]], %[[C1]]
// CHECK: %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[K_2]]]
-// CHECK: %[[SV3_2:.+]] = subview %[[ARG1]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[K_2]]]
-// CHECK: linalg.copy(%[[SV3]], %[[SV3_2]])
+// CHECK-SAME: [%[[TILE_M]], %[[K]]]
+// CHECK: linalg.copy(%[[SV3]], %[[SV1]])
// CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion_producer"
-// CHECK: linalg.fill(%[[SV2_2]], %[[CST]])
+// CHECK: linalg.fill(%[[SV2]], %[[CST]])
// CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion_producer"
// CHECK-DAG: %[[N_2:.+]] = dim %[[ARG2]], %[[C1]]
// CHECK: scf.parallel (%[[IV1:.+]]) =
@@ -271,18 +261,15 @@ module {
// CHECK: %[[N:.+]] = dim %[[ARG4]], %[[C1]]
// CHECK: %[[SV2:.+]] = subview %[[ARG4]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M_2]], %[[N]]]
-// CHECK: %[[K2_2:.+]] = dim %[[ARG1]], %[[C1]]
// CHECK: %[[K1:.+]] = dim %[[ARG0]], %[[C1]]
// CHECK: %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M]], %[[K1]]]
-// CHECK: %[[SV4:.+]] = subview %[[ARG1]][0, 0] [%[[K1]], %[[K2_2]]]
-// CHECK: %[[SV1_2:.+]] = subview %[[ARG2]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[K2_2]]]
+// CHECK: %[[SV4:.+]] = subview %[[ARG1]][0, 0] [%[[K1]], %[[K2]]]
// CHECK: linalg.matmul
// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_producer"
// CHECK-SAME: ins(%[[SV3]], %[[SV4]]
// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
-// CHECK-SAME: outs(%[[SV1_2]] : memref<?x?xf32, #[[MAP1]]>)
+// CHECK-SAME: outs(%[[SV1]] : memref<?x?xf32, #[[MAP1]]>)
// CHECK-DAG: %[[N_2:.+]] = dim %[[ARG3]], %[[C1]]
// CHECK: scf.parallel (%[[IV1:.+]]) =
// CHECK-SAME: (%[[C0]]) to (%[[N_2]]) step (%[[C64]]) {
diff --git a/mlir/test/Dialect/Linalg/fusion-sequence.mlir b/mlir/test/Dialect/Linalg/fusion-sequence.mlir
deleted file mode 100644
index a02c878ef341..000000000000
--- a/mlir/test/Dialect/Linalg/fusion-sequence.mlir
+++ /dev/null
@@ -1,133 +0,0 @@
-// RUN: mlir-opt -pass-pipeline="func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),canonicalize,cse" -split-input-file %s | FileCheck %s
-
-module {
- func @three_op_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
- %arg2: memref<?xf32>, %arg3 : memref<?x?xf32>) {
- %cst = constant 0.000000e+00 : f32
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %d0 = dim %arg0, %c0 : memref<?x?xf32>
- %d1 = dim %arg1, %c1 : memref<?x?xf32>
- %0 = alloc(%d0, %d1) : memref<?x?xf32>
- linalg.fill(%0, %cst) : memref<?x?xf32>, f32
- linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%0 : memref<?x?xf32>)
- linalg.generic
- {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d1)>,
- affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]}
- ins(%0, %arg2 : memref<?x?xf32>, memref<?xf32>)
- outs(%arg3 : memref<?x?xf32>) {
- ^bb0(%arg4 : f32, %arg5 : f32, %arg6 : f32) :
- %5 = addf %arg4, %arg5 : f32
- linalg.yield %5 : f32
- }
- return
- }
-}
-
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
-// CHECK: func @three_op_fusion
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?xf32>
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK: %[[TEMP:.+]] = alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32>
-// CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) = {{.*}} {
-// CHECK-DAG: %[[SV_TEMP:.+]] = subview %[[TEMP]][%[[IV0]], %[[IV1]]]
-// CHECK-DAG: %[[SV_ARG2:.+]] = subview %[[ARG2]][%[[IV1]]]
-// CHECK-DAG: %[[SV_ARG3:.+]] = subview %[[ARG3]][%[[IV0]], %[[IV1]]]
-// CHECK-DAG: %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[IV0]], 0]
-// CHECK-DAG: %[[SV_ARG1:.+]] = subview %[[ARG1]][0, %[[IV1]]]
-// CHECK: linalg.fill(%[[SV_TEMP]], %{{.+}})
-// CHECK: linalg.matmul
-// CHECK-SAME: ins(%[[SV_ARG0]], %[[SV_ARG1]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP2]]>, memref<?x?xf32, #[[MAP2]]>)
-// CHECK-SAME: outs(%[[SV_TEMP]] : memref<?x?xf32, #[[MAP2]]>)
-// CHECK: linalg.generic
-// CHECK-SAME: ins(%[[SV_TEMP]], %[[SV_ARG2]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP2]]>, memref<?xf32, #[[MAP3]]>)
-// CHECK-SAME: outs(%[[SV_ARG3]] : memref<?x?xf32, #[[MAP2]]>)
-// CHECK: scf.yield
-// CHECK: }
-
-// -----
-
-module {
- func @sequence_of_matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
- %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>,
- %arg4: memref<?x?xf32>) {
- %cst = constant 0.000000e+00 : f32
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %m = dim %arg0, %c0 : memref<?x?xf32>
- %n1 = dim %arg1, %c1 : memref<?x?xf32>
- %n2 = dim %arg2, %c1 : memref<?x?xf32>
- %n3 = dim %arg3, %c1 : memref<?x?xf32>
- %0 = alloc(%m, %n1) : memref<?x?xf32>
- %1 = alloc(%m, %n2) : memref<?x?xf32>
- linalg.fill(%0, %cst) : memref<?x?xf32>, f32
- linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%0 : memref<?x?xf32>)
- linalg.fill(%1, %cst) : memref<?x?xf32>, f32
- linalg.matmul ins(%0, %arg2 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%1 : memref<?x?xf32>)
- linalg.fill(%arg4, %cst) : memref<?x?xf32>, f32
- linalg.matmul ins(%1, %arg3 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%arg4 : memref<?x?xf32>)
- return
- }
-}
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-// CHECK: func @sequence_of_matmul
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C16:.+]] = constant 16 : index
-// CHECK-DAG: %[[M:.+]] = dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[N1:.+]] = dim %[[ARG1]], %[[C1]]
-// CHECK-DAG: %[[N2:.+]] = dim %[[ARG2]], %[[C1]]
-// CHECK: %[[ALLOC1:.+]] = alloc(%[[M]], %[[N1]])
-// CHECK: %[[ALLOC2:.+]] = alloc(%[[M]], %[[N2]])
-// CHECK: scf.parallel (%[[IV0:.+]]) = (%[[C0]]) to (%[[M]])
-// CHECK-SAME: step (%[[C16]]) {
-// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
-// CHECK: %[[SV_ALLOC2:.+]] = subview %[[ALLOC2]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[N2]]]
-// CHECK: %[[M_2:.+]] = dim %[[ARG4]], %[[C0]]
-// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]]
-// CHECK: %[[N3:.+]] = dim %[[ARG4]], %[[C1]]
-// CHECK: %[[SV_ARG4:.+]] = subview %[[ARG4]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_2]], %[[N3]]]
-// CHECK: %[[SV_ARG4_2:.+]] = subview %[[ARG4]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[N3]]]
-// CHECK: %[[SV_ALLOC1:.+]] = subview %[[ALLOC1]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[N1]]]
-// CHECK: %[[SV_ARG2:.+]] = subview %[[ARG2]][0, 0] [%[[N1]], %[[N2]]]
-// CHECK: %[[N0:.+]] = dim %[[ARG0]], %[[C1]]
-// CHECK: %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M:.+]], %[[N0]]]
-// CHECK: %[[SV_ARG1:.+]] = subview %[[ARG1]][0, 0] [%[[N0]], %[[N1]]]
-// CHECK: linalg.fill(%[[SV_ALLOC1]], %{{.+}})
-// CHECK: linalg.matmul ins(%[[SV_ARG0]], %[[SV_ARG1]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
-// CHECK-SAME: outs(%[[SV_ALLOC1]] : memref<?x?xf32, #[[MAP1]]>)
-// CHECK: linalg.fill(%[[SV_ALLOC2]], %{{.+}})
-// CHECK: linalg.matmul ins(%[[SV_ALLOC1]], %[[SV_ARG2]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
-// CHECK-SAME: outs(%[[SV_ALLOC2]] : memref<?x?xf32, #[[MAP1]]>)
-// CHECK: linalg.fill(%[[SV_ARG4_2]], %{{.+}})
-// CHECK: linalg.matmul ins(%[[SV_ALLOC2]], %[[ARG3]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
-// CHECK-SAME: outs(%[[SV_ARG4]] : memref<?x?xf32, #[[MAP1]]>)
-// CHECK: scf.yield
-// CHECK: }
-
diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index 5289b2d1055f..eb9e3a533138 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -197,44 +197,6 @@ struct TestLinalgGreedyFusion
}
}
};
-
-/// Pass to test tile and fuse of sequence of operations. Intended only for
-/// testing.
-struct TestLinalgTileAndFuseSequencePass
- : public PassWrapper<TestLinalgTileAndFuseSequencePass, FunctionPass> {
- TestLinalgTileAndFuseSequencePass() = default;
- TestLinalgTileAndFuseSequencePass(
- const TestLinalgTileAndFuseSequencePass &pass){};
-
- ListOption<int64_t> tileSizes{
- *this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"),
- llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
-
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
- }
-
- void runOnFunction() override {
- FuncOp funcOp = getOperation();
- auto &blocks = funcOp.getBody().getBlocks();
- if (!llvm::hasSingleElement(blocks)) {
- return;
- }
- SmallVector<LinalgOp, 2> linalgOps =
- llvm::to_vector<2>(blocks.front().getOps<LinalgOp>());
- Aliases aliases;
- LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
- OpBuilder builder(funcOp.getContext());
- Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps(
- builder, linalgOps, dependenceGraph,
- LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(
- LinalgTilingLoopType::ParallelLoops));
- if (!tileAndFuseOps)
- return signalPassFailure();
- for (auto op : linalgOps)
- op.erase();
- }
-};
} // namespace
namespace mlir {
@@ -249,12 +211,5 @@ void registerTestLinalgGreedyFusion() {
"test-linalg-greedy-fusion",
"Test Linalg fusion by applying a greedy test transformation.");
}
-void registerTestLinalgTileAndFuseSequencePass() {
- PassRegistration<TestLinalgTileAndFuseSequencePass>
- testTileAndFuseSequencePass(
- "test-linalg-tile-and-fuse",
- "Test Linalg tiling and fusion of a sequence of Linalg operations.");
-}
-
} // namespace test
} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index a0e36cf82534..4771b11b20e4 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -74,7 +74,6 @@ void registerTestLinalgCodegenStrategy();
void registerTestLinalgFusionTransforms();
void registerTestLinalgGreedyFusion();
void registerTestLinalgHoisting();
-void registerTestLinalgTileAndFuseSequencePass();
void registerTestLinalgTransforms();
void registerTestLivenessPass();
void registerTestLoopFusion();
@@ -142,7 +141,6 @@ void registerTestPasses() {
test::registerTestLinalgFusionTransforms();
test::registerTestLinalgGreedyFusion();
test::registerTestLinalgHoisting();
- test::registerTestLinalgTileAndFuseSequencePass();
test::registerTestLinalgTransforms();
test::registerTestLivenessPass();
test::registerTestLoopFusion();
More information about the Mlir-commits
mailing list