[Mlir-commits] [mlir] 735a07f - Revert "[mlir][Affine] Add support for multi-store producer fusion"
Diego Caballero
llvmlistbot at llvm.org
Wed Jan 20 14:39:01 PST 2021
Author: Diego Caballero
Date: 2021-01-21T00:37:23+02:00
New Revision: 735a07f0478566f6f7c60a8a98eb8884db574113
URL: https://github.com/llvm/llvm-project/commit/735a07f0478566f6f7c60a8a98eb8884db574113
DIFF: https://github.com/llvm/llvm-project/commit/735a07f0478566f6f7c60a8a98eb8884db574113.diff
LOG: Revert "[mlir][Affine] Add support for multi-store producer fusion"
This reverts commit 7dd198852b4db52ae22242dfeda4eccda83aa8b2.
ASAN issue.
Added:
Modified:
mlir/include/mlir/Analysis/AffineStructures.h
mlir/include/mlir/Analysis/Utils.h
mlir/include/mlir/Transforms/LoopFusionUtils.h
mlir/include/mlir/Transforms/Passes.td
mlir/lib/Analysis/AffineStructures.cpp
mlir/lib/Analysis/Utils.cpp
mlir/lib/Transforms/LoopFusion.cpp
mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
mlir/test/Transforms/loop-fusion.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
index 893d4ea4ff46..fa80db7d4b63 100644
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -234,21 +234,6 @@ class FlatAffineConstraints {
// TODO: add support for non-unit strides.
LogicalResult addAffineForOpDomain(AffineForOp forOp);
- /// Adds constraints (lower and upper bounds) for each loop in the loop nest
- /// described by the bound maps 'lbMaps' and 'ubMaps' of a computation slice.
- /// Every pair ('lbMaps[i]', 'ubMaps[i]') describes the bounds of a loop in
- /// the nest, sorted outer-to-inner. 'operands' contains the bound operands
- /// for a single bound map. All the bound maps will use the same bound
- /// operands. Note that some loops described by a computation slice might not
- /// exist yet in the IR so the Value attached to those dimension identifiers
- /// might be empty. For that reason, this method doesn't perform Value
- /// look-ups to retrieve the dimension identifier positions. Instead, it
- /// assumes the position of the dim identifiers in the constraint system is
- /// the same as the position of the loop in the loop nest.
- LogicalResult addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps,
- ArrayRef<AffineMap> ubMaps,
- ArrayRef<Value> operands);
-
/// Adds constraints imposed by the `affine.if` operation. These constraints
/// are collected from the IntegerSet attached to the given `affine.if`
/// instance argument (`ifOp`). It is asserted that:
diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h
index ee6f8095f25e..30b6272181f5 100644
--- a/mlir/include/mlir/Analysis/Utils.h
+++ b/mlir/include/mlir/Analysis/Utils.h
@@ -83,25 +83,10 @@ struct ComputationSliceState {
// Clears all bounds and operands in slice state.
void clearBounds();
- /// Returns true if the computation slice is empty.
+ /// Return true if the computation slice is empty.
bool isEmpty() const { return ivs.empty(); }
- /// Returns true if the computation slice encloses all the iterations of the
- /// sliced loop nest. Returns false if it does not. Returns llvm::None if it
- /// cannot determine if the slice is maximal or not.
- // TODO: Cache 'isMaximal' so that we don't recompute it when the slice
- // information hasn't changed.
- Optional<bool> isMaximal() const;
-
void dump() const;
-
-private:
- /// Fast check to determine if the computation slice is maximal. Returns true
- /// if each slice dimension maps to an existing dst dimension and both the src
- /// and the dst loops for those dimensions have the same bounds. Returns false
- /// if both the src and the dst loops don't have the same bounds. Returns
- /// llvm::None if none of the above can be proven.
- Optional<bool> isSliceMaximalFastCheck() const;
};
/// Computes the computation slice loop bounds for one loop nest as affine maps
diff --git a/mlir/include/mlir/Transforms/LoopFusionUtils.h b/mlir/include/mlir/Transforms/LoopFusionUtils.h
index 10d6b83d022f..eade565e0325 100644
--- a/mlir/include/mlir/Transforms/LoopFusionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopFusionUtils.h
@@ -50,8 +50,7 @@ struct FusionResult {
// TODO: Generalize utilities so that producer-consumer and sibling fusion
// strategies can be used without the assumptions made in the AffineLoopFusion
// pass.
-class FusionStrategy {
-public:
+struct FusionStrategy {
enum StrategyEnum {
// Generic loop fusion: Arbitrary loops are considered for fusion. No
// assumptions about a specific fusion strategy from AffineLoopFusion pass
@@ -70,34 +69,13 @@ class FusionStrategy {
// implementation in AffineLoopFusion pass are made. See pass for specific
// details.
Sibling
- };
+ } strategy;
- /// Construct a generic or producer-consumer fusion strategy.
- FusionStrategy(StrategyEnum strategy) : strategy(strategy) {
- assert(strategy != Sibling &&
- "Sibling fusion strategy requires a specific memref");
- }
-
- /// Construct a sibling fusion strategy targeting 'memref'. This construct
- /// should only be used for sibling fusion.
- FusionStrategy(Value memref) : strategy(Sibling), memref(memref) {}
-
- /// Returns the fusion strategy.
- StrategyEnum getStrategy() const { return strategy; };
-
- /// Returns the memref attached to this sibling fusion strategy.
- Value getSiblingFusionMemRef() const {
- assert(strategy == Sibling && "Memref is only valid for sibling fusion");
- return memref;
- }
-
-private:
- /// Fusion strategy.
- StrategyEnum strategy;
-
- /// Target memref for this fusion transformation. Only used for sibling
- /// fusion.
+ // Target memref for this fusion transformation.
Value memref;
+
+ FusionStrategy(StrategyEnum strategy, Value memref)
+ : strategy(strategy), memref(memref) {}
};
/// Checks the feasibility of fusing the loop nest rooted at 'srcForOp' into the
@@ -108,10 +86,11 @@ class FusionStrategy {
/// NOTE: This function is not feature complete and should only be used in
/// testing.
/// TODO: Update comments when this function is fully implemented.
-FusionResult
-canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth,
- ComputationSliceState *srcSlice,
- FusionStrategy fusionStrategy = FusionStrategy::Generic);
+FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
+ unsigned dstLoopDepth,
+ ComputationSliceState *srcSlice,
+ FusionStrategy fusionStrategy = {
+ FusionStrategy::Generic, Value()});
/// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point
/// and source slice loop bounds specified in 'srcSlice'.
@@ -155,12 +134,6 @@ bool getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
const ComputationSliceState &slice,
int64_t *computeCost);
-/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
-/// producer-consumer dependence between write ops in 'srcOps' and read ops in
-/// 'dstOps'.
-void gatherProducerConsumerMemrefs(ArrayRef<Operation *> srcOps,
- ArrayRef<Operation *> dstOps,
- DenseSet<Value> &producerConsumerMemrefs);
} // end namespace mlir
#endif // MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index a03b439af339..438a468673b5 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -17,111 +17,6 @@ include "mlir/Pass/PassBase.td"
def AffineLoopFusion : FunctionPass<"affine-loop-fusion"> {
let summary = "Fuse affine loop nests";
- let description = [{
- This pass performs fusion of loop nests using a slicing-based approach. It
- combines two fusion strategies: producer-consumer fusion and sibling fusion.
- Producer-consumer fusion is aimed at fusing pairs of loops where the first
- one writes to a memref that the second reads. Sibling fusion targets pairs
- of loops that share no dependences between them but that load from the same
- memref. The fused loop nests, when possible, are rewritten to access
- significantly smaller local buffers instead of the original memref's, and
- the latter are often either completely optimized away or contracted. This
- transformation leads to enhanced locality and lower memory footprint through
- the elimination or contraction of temporaries/intermediate memref's. These
- benefits are sometimes achieved at the expense of redundant computation
- through a cost model that evaluates available choices such as the depth at
- which a source slice should be materialized in the designation slice.
-
- Example 1: Producer-consumer fusion.
- Input:
- ```mlir
- func @producer_consumer_fusion(%arg0: memref<10xf32>, %arg1: memref<10xf32>) {
- %0 = alloc() : memref<10xf32>
- %1 = alloc() : memref<10xf32>
- %cst = constant 0.000000e+00 : f32
- affine.for %arg2 = 0 to 10 {
- affine.store %cst, %0[%arg2] : memref<10xf32>
- affine.store %cst, %1[%arg2] : memref<10xf32>
- }
- affine.for %arg2 = 0 to 10 {
- %2 = affine.load %0[%arg2] : memref<10xf32>
- %3 = addf %2, %2 : f32
- affine.store %3, %arg0[%arg2] : memref<10xf32>
- }
- affine.for %arg2 = 0 to 10 {
- %2 = affine.load %1[%arg2] : memref<10xf32>
- %3 = mulf %2, %2 : f32
- affine.store %3, %arg1[%arg2] : memref<10xf32>
- }
- return
- }
- ```
- Output:
- ```mlir
- func @producer_consumer_fusion(%arg0: memref<10xf32>, %arg1: memref<10xf32>) {
- %0 = alloc() : memref<1xf32>
- %1 = alloc() : memref<1xf32>
- %cst = constant 0.000000e+00 : f32
- affine.for %arg2 = 0 to 10 {
- affine.store %cst, %0[0] : memref<1xf32>
- affine.store %cst, %1[0] : memref<1xf32>
- %2 = affine.load %1[0] : memref<1xf32>
- %3 = mulf %2, %2 : f32
- affine.store %3, %arg1[%arg2] : memref<10xf32>
- %4 = affine.load %0[0] : memref<1xf32>
- %5 = addf %4, %4 : f32
- affine.store %5, %arg0[%arg2] : memref<10xf32>
- }
- return
- }
- ```
-
- Example 2: Sibling fusion.
- Input:
- ```mlir
- func @sibling_fusion(%arg0: memref<10x10xf32>, %arg1: memref<10x10xf32>,
- %arg2: memref<10x10xf32>, %arg3: memref<10x10xf32>,
- %arg4: memref<10x10xf32>) {
- affine.for %arg5 = 0 to 3 {
- affine.for %arg6 = 0 to 3 {
- %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32>
- %1 = affine.load %arg1[%arg5, %arg6] : memref<10x10xf32>
- %2 = mulf %0, %1 : f32
- affine.store %2, %arg3[%arg5, %arg6] : memref<10x10xf32>
- }
- }
- affine.for %arg5 = 0 to 3 {
- affine.for %arg6 = 0 to 3 {
- %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32>
- %1 = affine.load %arg2[%arg5, %arg6] : memref<10x10xf32>
- %2 = addf %0, %1 : f32
- affine.store %2, %arg4[%arg5, %arg6] : memref<10x10xf32>
- }
- }
- return
- }
- ```
- Output:
- ```mlir
- func @sibling_fusion(%arg0: memref<10x10xf32>, %arg1: memref<10x10xf32>,
- %arg2: memref<10x10xf32>, %arg3: memref<10x10xf32>,
- %arg4: memref<10x10xf32>) {
- affine.for %arg5 = 0 to 3 {
- affine.for %arg6 = 0 to 3 {
- %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32>
- %1 = affine.load %arg1[%arg5, %arg6] : memref<10x10xf32>
- %2 = mulf %0, %1 : f32
- affine.store %2, %arg3[%arg5, %arg6] : memref<10x10xf32>
- %3 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32>
- %4 = affine.load %arg2[%arg5, %arg6] : memref<10x10xf32>
- %5 = addf %3, %4 : f32
- affine.store %5, %arg4[%arg5, %arg6] : memref<10x10xf32>
- }
- }
- return
- }
- ```
- }];
let constructor = "mlir::createLoopFusionPass()";
let options = [
Option<"computeToleranceThreshold", "fusion-compute-tolerance", "double",
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index 81dc7855184e..12c90fbcfc54 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -708,70 +708,6 @@ LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) {
/*eq=*/false, /*lower=*/false);
}
-/// Adds constraints (lower and upper bounds) for each loop in the loop nest
-/// described by the bound maps 'lbMaps' and 'ubMaps' of a computation slice.
-/// Every pair ('lbMaps[i]', 'ubMaps[i]') describes the bounds of a loop in
-/// the nest, sorted outer-to-inner. 'operands' contains the bound operands
-/// for a single bound map. All the bound maps will use the same bound
-/// operands. Note that some loops described by a computation slice might not
-/// exist yet in the IR so the Value attached to those dimension identifiers
-/// might be empty. For that reason, this method doesn't perform Value
-/// look-ups to retrieve the dimension identifier positions. Instead, it
-/// assumes the position of the dim identifiers in the constraint system is
-/// the same as the position of the loop in the loop nest.
-LogicalResult
-FlatAffineConstraints::addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps,
- ArrayRef<AffineMap> ubMaps,
- ArrayRef<Value> operands) {
- assert(lbMaps.size() == ubMaps.size());
- assert(lbMaps.size() <= getNumDimIds());
-
- for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
- AffineMap lbMap = lbMaps[i];
- AffineMap ubMap = ubMaps[i];
- assert(!lbMap || lbMap.getNumInputs() == operands.size());
- assert(!ubMap || ubMap.getNumInputs() == operands.size());
-
- // Check if this slice is just an equality along this dimension. If so,
- // retrieve the existing loop it equates to and add it to the system.
- if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
- ubMap.getNumResults() == 1 &&
- lbMap.getResult(0) + 1 == ubMap.getResult(0) &&
- // The condition above will be true for maps describing a single
- // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
- // Make sure we skip those cases by checking that the lb result is not
- // just a constant.
- !lbMap.getResult(0).isa<AffineConstantExpr>()) {
- // Limited support: we expect the lb result to be just a loop dimension.
- // Not supported otherwise for now.
- AffineDimExpr result = lbMap.getResult(0).dyn_cast<AffineDimExpr>();
- if (!result)
- return failure();
-
- AffineForOp loop =
- getForInductionVarOwner(operands[result.getPosition()]);
- if (!loop)
- return failure();
-
- if (failed(addAffineForOpDomain(loop)))
- return failure();
- continue;
- }
-
- // This slice refers to a loop that doesn't exist in the IR yet. Add its
- // bounds to the system assuming its dimension identifier position is the
- // same as the position of the loop in the loop nest.
- if (lbMap && failed(addLowerOrUpperBound(i, lbMap, operands, /*eq=*/false,
- /*lower=*/true)))
- return failure();
-
- if (ubMap && failed(addLowerOrUpperBound(i, ubMap, operands, /*eq=*/false,
- /*lower=*/false)))
- return failure();
- }
- return success();
-}
-
void FlatAffineConstraints::addAffineIfOpDomain(AffineIfOp ifOp) {
// Create the base constraints from the integer set attached to ifOp.
FlatAffineConstraints cst(ifOp.getIntegerSet());
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index 383a6587bbef..a1e7d1ffe844 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -12,8 +12,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Utils.h"
+
#include "mlir/Analysis/AffineAnalysis.h"
-#include "mlir/Analysis/PresburgerSet.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -127,128 +127,6 @@ void ComputationSliceState::dump() const {
}
}
-/// Fast check to determine if the computation slice is maximal. Returns true if
-/// each slice dimension maps to an existing dst dimension and both the src
-/// and the dst loops for those dimensions have the same bounds. Returns false
-/// if both the src and the dst loops don't have the same bounds. Returns
-/// llvm::None if none of the above can be proven.
-Optional<bool> ComputationSliceState::isSliceMaximalFastCheck() const {
- assert(lbs.size() == ubs.size() && lbs.size() && ivs.size() &&
- "Unexpected number of lbs, ubs and ivs in slice");
-
- for (unsigned i = 0, end = lbs.size(); i < end; ++i) {
- AffineMap lbMap = lbs[i];
- AffineMap ubMap = ubs[i];
-
- // Check if this slice is just an equality along this dimension.
- if (!lbMap || !ubMap || lbMap.getNumResults() != 1 ||
- ubMap.getNumResults() != 1 ||
- lbMap.getResult(0) + 1 != ubMap.getResult(0) ||
- // The condition above will be true for maps describing a single
- // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
- // Make sure we skip those cases by checking that the lb result is not
- // just a constant.
- lbMap.getResult(0).isa<AffineConstantExpr>())
- return llvm::None;
-
- // Limited support: we expect the lb result to be just a loop dimension for
- // now.
- AffineDimExpr result = lbMap.getResult(0).dyn_cast<AffineDimExpr>();
- if (!result)
- return llvm::None;
-
- // Retrieve dst loop bounds.
- AffineForOp dstLoop =
- getForInductionVarOwner(lbOperands[i][result.getPosition()]);
- if (!dstLoop)
- return llvm::None;
- AffineMap dstLbMap = dstLoop.getLowerBoundMap();
- AffineMap dstUbMap = dstLoop.getUpperBoundMap();
-
- // Retrieve src loop bounds.
- AffineForOp srcLoop = getForInductionVarOwner(ivs[i]);
- assert(srcLoop && "Expected affine for");
- AffineMap srcLbMap = srcLoop.getLowerBoundMap();
- AffineMap srcUbMap = srcLoop.getUpperBoundMap();
-
- // Limited support: we expect simple src and dst loops with a single
- // constant component per bound for now.
- if (srcLbMap.getNumResults() != 1 || srcUbMap.getNumResults() != 1 ||
- dstLbMap.getNumResults() != 1 || dstUbMap.getNumResults() != 1)
- return llvm::None;
-
- AffineExpr srcLbResult = srcLbMap.getResult(0);
- AffineExpr dstLbResult = dstLbMap.getResult(0);
- AffineExpr srcUbResult = srcUbMap.getResult(0);
- AffineExpr dstUbResult = dstUbMap.getResult(0);
- if (!srcLbResult.isa<AffineConstantExpr>() ||
- !srcUbResult.isa<AffineConstantExpr>() ||
- !dstLbResult.isa<AffineConstantExpr>() ||
- !dstUbResult.isa<AffineConstantExpr>())
- return llvm::None;
-
- // Check if src and dst loop bounds are the same. If not, we can guarantee
- // that the slice is not maximal.
- if (srcLbResult != dstLbResult || srcUbResult != dstUbResult)
- return false;
- }
-
- return true;
-}
-
-/// Returns true if the computation slice encloses all the iterations of the
-/// sliced loop nest. Returns false if it does not. Returns llvm::None if it
-/// cannot determine if the slice is maximal or not.
-Optional<bool> ComputationSliceState::isMaximal() const {
- // Fast check to determine if the computation slice is maximal. If the result
- // is inconclusive, we proceed with a more expensive analysis.
- Optional<bool> isMaximalFastCheck = isSliceMaximalFastCheck();
- if (isMaximalFastCheck.hasValue())
- return isMaximalFastCheck;
-
- // Create constraints for the src loop nest being sliced.
- FlatAffineConstraints srcConstraints;
- srcConstraints.reset(/*numDims=*/ivs.size(), /*numSymbols=*/0,
- /*numLocals=*/0, ivs);
- for (Value iv : ivs) {
- AffineForOp loop = getForInductionVarOwner(iv);
- assert(loop && "Expected affine for");
- if (failed(srcConstraints.addAffineForOpDomain(loop)))
- return llvm::None;
- }
-
- // Create constraints for the slice using the dst loop nest information. We
- // retrieve existing dst loops from the lbOperands.
- SmallVector<Value, 8> consumerIVs;
- for (Value lbOp : lbOperands[0])
- if (getForInductionVarOwner(lbOp))
- consumerIVs.push_back(lbOp);
-
- // Add empty IV Values for those new loops that are not equalities and,
- // therefore, are not yet materialized in the IR.
- for (int i = consumerIVs.size(), end = ivs.size(); i < end; ++i)
- consumerIVs.push_back(Value());
-
- FlatAffineConstraints sliceConstraints;
- sliceConstraints.reset(/*numDims=*/consumerIVs.size(), /*numSymbols=*/0,
- /*numLocals=*/0, consumerIVs);
-
- if (failed(sliceConstraints.addDomainFromSliceMaps(lbs, ubs, lbOperands[0])))
- return llvm::None;
-
- if (srcConstraints.getNumDimIds() != sliceConstraints.getNumDimIds())
- // Constraint dims are different. The integer set difference can't be
- // computed so we don't know if the slice is maximal.
- return llvm::None;
-
- // Compute the difference between the src loop nest and the slice integer
- // sets.
- PresburgerSet srcSet(srcConstraints);
- PresburgerSet sliceSet(sliceConstraints);
- PresburgerSet diffSet = srcSet.subtract(sliceSet);
- return diffSet.isIntegerEmpty();
-}
-
unsigned MemRefRegion::getRank() const {
return memref.getType().cast<MemRefType>().getRank();
}
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 6c56368ca6e1..6fe112b89baf 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -30,7 +30,6 @@
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <iomanip>
-#include <set>
#include <sstream>
#define DEBUG_TYPE "affine-loop-fusion"
@@ -271,6 +270,64 @@ struct MemRefDependenceGraph {
return false;
}
+ // Returns the unique AffineWriteOpInterface in `node` that meets all the
+ // following:
+ // *) store is the only one that writes to a function-local memref live out
+ // of `node`,
+ // *) store is not the source of a self-dependence on `node`.
+ // Otherwise, returns a null AffineWriteOpInterface.
+ AffineWriteOpInterface getUniqueOutgoingStore(Node *node) {
+ AffineWriteOpInterface uniqueStore;
+
+ // Return null if `node` doesn't have any outgoing edges.
+ auto outEdgeIt = outEdges.find(node->id);
+ if (outEdgeIt == outEdges.end())
+ return nullptr;
+
+ const auto &nodeOutEdges = outEdgeIt->second;
+ for (auto *op : node->stores) {
+ auto storeOp = cast<AffineWriteOpInterface>(op);
+ auto memref = storeOp.getMemRef();
+ // Skip this store if there are no dependences on its memref. This means
+ // that store either:
+ // *) writes to a memref that is only read within the same loop nest
+ // (self-dependence edges are not represented in graph at the moment),
+ // *) writes to a function live out memref (function parameter), or
+ // *) is dead.
+ if (llvm::all_of(nodeOutEdges, [=](const Edge &edge) {
+ return (edge.value != memref);
+ }))
+ continue;
+
+ if (uniqueStore)
+ // Found multiple stores to function-local live-out memrefs.
+ return nullptr;
+ // Found first store to function-local live-out memref.
+ uniqueStore = storeOp;
+ }
+
+ return uniqueStore;
+ }
+
+ // Returns true if node 'id' can be removed from the graph. Returns false
+ // otherwise. A node can be removed from the graph iff the following
+ // conditions are met:
+ // *) The node does not write to any memref which escapes (or is a
+ // function/block argument).
+ // *) The node has no successors in the dependence graph.
+ bool canRemoveNode(unsigned id) {
+ if (writesToLiveInOrEscapingMemrefs(id))
+ return false;
+ Node *node = getNode(id);
+ for (auto *storeOpInst : node->stores) {
+ // Return false if there exist out edges from 'id' on 'memref'.
+ auto storeMemref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
+ if (getOutEdgeCount(id, storeMemref) > 0)
+ return false;
+ }
+ return true;
+ }
+
// Returns true iff there is an edge from node 'srcId' to node 'dstId' which
// is for 'value' if non-null, or for any value otherwise. Returns false
// otherwise.
@@ -438,49 +495,42 @@ struct MemRefDependenceGraph {
return dstNodeInst;
}
- // Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
- // taking into account that:
- // *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
- // *) memrefs in 'privateMemRefs' has been replaced in node at 'dstId' by a
- // private memref.
- void updateEdges(unsigned srcId, unsigned dstId,
- const DenseSet<Value> &privateMemRefs, bool removeSrcId) {
+ // Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef'
+ // has been replaced in node at 'dstId' by a private memref depending
+ // on the value of 'createPrivateMemRef'.
+ void updateEdges(unsigned srcId, unsigned dstId, Value oldMemRef,
+ bool createPrivateMemRef) {
// For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
if (inEdges.count(srcId) > 0) {
SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
for (auto &inEdge : oldInEdges) {
- // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
- if (privateMemRefs.count(inEdge.value) == 0)
+ // Add edge from 'inEdge.id' to 'dstId' if not for 'oldMemRef'.
+ if (inEdge.value != oldMemRef)
addEdge(inEdge.id, dstId, inEdge.value);
}
}
// For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
- // If 'srcId' is going to be removed, remap all the out edges to 'dstId'.
if (outEdges.count(srcId) > 0) {
SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
for (auto &outEdge : oldOutEdges) {
// Remove any out edges from 'srcId' to 'dstId' across memrefs.
if (outEdge.id == dstId)
removeEdge(srcId, outEdge.id, outEdge.value);
- else if (removeSrcId) {
- addEdge(dstId, outEdge.id, outEdge.value);
- removeEdge(srcId, outEdge.id, outEdge.value);
- }
}
}
// Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
// replaced by a private memref). These edges could come from nodes
// other than 'srcId' which were removed in the previous step.
- if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) {
+ if (inEdges.count(dstId) > 0 && createPrivateMemRef) {
SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
for (auto &inEdge : oldInEdges)
- if (privateMemRefs.count(inEdge.value) > 0)
+ if (inEdge.value == oldMemRef)
removeEdge(inEdge.id, dstId, inEdge.value);
}
}
// Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
- // of sibling node 'sibId' into node 'dstId'.
+ // of sibling node 'sidId' into node 'dstId'.
void updateEdges(unsigned sibId, unsigned dstId) {
// For each edge in 'inEdges[sibId]':
// *) Add new edge from source node 'inEdge.id' to 'dstNode'.
@@ -574,132 +624,6 @@ struct MemRefDependenceGraph {
void dump() const { print(llvm::errs()); }
};
-/// Returns true if node 'srcId' can be removed after fusing it with node
-/// 'dstId'. The node can be removed if any of the following conditions are met:
-/// 1. 'srcId' has no output dependences after fusion and no escaping memrefs.
-/// 2. 'srcId' has no output dependences after fusion, has escaping memrefs
-/// and the fusion slice is maximal.
-/// 3. 'srcId' has output dependences after fusion, the fusion slice is
-/// maximal and the fusion insertion point dominates all the dependences.
-static bool canRemoveSrcNodeAfterFusion(
- unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice,
- Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs,
- MemRefDependenceGraph *mdg) {
-
- Operation *dstNodeOp = mdg->getNode(dstId)->op;
- bool hasOutDepsAfterFusion = false;
-
- for (auto &outEdge : mdg->outEdges[srcId]) {
- Operation *depNodeOp = mdg->getNode(outEdge.id)->op;
- // Skip dependence with dstOp since it will be removed after fusion.
- if (depNodeOp == dstNodeOp)
- continue;
-
- // Only fusion within the same block is supported. Use domination analysis
- // when needed.
- if (depNodeOp->getBlock() != dstNodeOp->getBlock())
- return false;
-
- // Check if the insertion point of the fused loop dominates the dependence.
- // Otherwise, the src loop can't be removed.
- if (fusedLoopInsPoint != depNodeOp &&
- !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) {
- LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't "
- "dominate dependence\n");
- return false;
- }
-
- hasOutDepsAfterFusion = true;
- }
-
- // If src loop has dependences after fusion or it writes to an live-out or
- // escaping memref, we can only remove it if the fusion slice is maximal so
- // that all the dependences are preserved.
- if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
- Optional<bool> isMaximal = fusionSlice.isMaximal();
- if (!isMaximal.hasValue()) {
- LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine "
- "if fusion is maximal\n");
- return false;
- }
-
- if (!isMaximal.getValue()) {
- LLVM_DEBUG(llvm::dbgs()
- << "Src loop can't be removed: fusion is not maximal\n");
- return false;
- }
- }
-
- return true;
-}
-
-/// Returns in 'srcIdCandidates' the producer fusion candidates for consumer
-/// 'dstId'.
-// TODO: Move this to a loop fusion utility once 'mdg' is also moved.
-static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg,
- DenseSet<unsigned> &srcIdCandidates) {
- // Skip if no input edges along which to fuse.
- if (mdg->inEdges.count(dstId) == 0)
- return;
-
- // Gather memrefs from loads in 'dstId'.
- auto *dstNode = mdg->getNode(dstId);
- DenseSet<Value> consumedMemrefs;
- for (Operation *load : dstNode->loads)
- consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());
-
- // Traverse 'dstId' incoming edges and gather the nodes that contain a store
- // to one of the consumed memrefs.
- for (auto &srcEdge : mdg->inEdges[dstId]) {
- auto *srcNode = mdg->getNode(srcEdge.id);
- // Skip if 'srcNode' is not a loop nest.
- if (!isa<AffineForOp>(srcNode->op))
- continue;
-
- if (any_of(srcNode->stores, [&](Operation *op) {
- auto storeOp = cast<AffineWriteOpInterface>(op);
- return consumedMemrefs.count(storeOp.getMemRef()) > 0;
- }))
- srcIdCandidates.insert(srcNode->id);
- }
-}
-
-/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
-/// producer-consumer dependence between 'srcId' and 'dstId'.
-static void
-gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
- MemRefDependenceGraph *mdg,
- DenseSet<Value> &producerConsumerMemrefs) {
- auto *dstNode = mdg->getNode(dstId);
- auto *srcNode = mdg->getNode(srcId);
- gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads,
- producerConsumerMemrefs);
-}
-
-/// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
-/// that escape the function. A memref escapes the function if either:
-/// 1. It's a function argument, or
-/// 2. It's used by a non-affine op (e.g., std load/store, std call, etc.)
-void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
- DenseSet<Value> &escapingMemRefs) {
- auto *node = mdg->getNode(id);
- for (auto *storeOpInst : node->stores) {
- auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
- if (escapingMemRefs.count(memref))
- continue;
- // Check if 'memref' escapes because it's a block argument.
- if (memref.isa<BlockArgument>()) {
- escapingMemRefs.insert(memref);
- continue;
- }
- // Check if 'memref' escapes through a non-affine op (e.g., std load/store,
- // call op, etc.).
- for (Operation *user : memref.getUsers())
- if (!isMemRefDereferencingOp(*user))
- escapingMemRefs.insert(memref);
- }
-}
-
} // end anonymous namespace
// Initializes the data dependence graph by walking operations in 'f'.
@@ -707,7 +631,6 @@ void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
// TODO: Add support for taking a Block arg to construct the
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(FuncOp f) {
- LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
DenseMap<Value, SetVector<unsigned>> memrefAccesses;
// TODO: support multi-block functions.
@@ -763,12 +686,6 @@ bool MemRefDependenceGraph::init(FuncOp f) {
}
}
-#ifndef NDEBUG
- for (auto &idAndNode : nodes)
- LLVM_DEBUG(llvm::dbgs() << "Create node " << idAndNode.first << " for:\n"
- << *(idAndNode.second.op) << "\n");
-#endif
-
// Add dependence edges between nodes which produce SSA values and their
// users.
for (auto &idAndNode : nodes) {
@@ -808,6 +725,22 @@ bool MemRefDependenceGraph::init(FuncOp f) {
return true;
}
+// Removes load operations from 'srcLoads' which operate on 'memref', and
+// adds them to 'dstLoads'.
+static void moveLoadsAccessingMemrefTo(Value memref,
+ SmallVectorImpl<Operation *> *srcLoads,
+ SmallVectorImpl<Operation *> *dstLoads) {
+ dstLoads->clear();
+ SmallVector<Operation *, 4> srcLoadsToKeep;
+ for (auto *load : *srcLoads) {
+ if (cast<AffineReadOpInterface>(load).getMemRef() == memref)
+ dstLoads->push_back(load);
+ else
+ srcLoadsToKeep.push_back(load);
+ }
+ srcLoads->swap(srcLoadsToKeep);
+}
+
// Sinks all sequential loops to the innermost levels (while preserving
// relative order among them) and moves all parallel loops to the
// outermost (while again preserving relative order among them).
@@ -999,6 +932,75 @@ static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
return false;
}
+// Checks if node 'srcId' can be safely fused into node 'dstId'. Node 'srcId'
+// may write to multiple memrefs but it is required that only one of them,
+// 'srcLiveOutStoreOp', has output edges.
+// Returns true if 'dstNode's read/write region to 'memref' is a super set of
+// 'srcNode's write region to 'memref' and 'srcId' has only one output edge.
+// TODO: Generalize this to handle more live in/out cases.
+static bool
+canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
+ AffineWriteOpInterface srcLiveOutStoreOp,
+ MemRefDependenceGraph *mdg) {
+ assert(srcLiveOutStoreOp && "Expected a valid store op");
+ auto *dstNode = mdg->getNode(dstId);
+ Value memref = srcLiveOutStoreOp.getMemRef();
+ // Return false if 'srcNode' has more than one output edge on 'memref'.
+ if (mdg->getOutEdgeCount(srcId, memref) > 1)
+ return false;
+
+ // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOp' on 'memref'.
+ MemRefRegion srcWriteRegion(srcLiveOutStoreOp.getLoc());
+ if (failed(srcWriteRegion.compute(srcLiveOutStoreOp, /*loopDepth=*/0))) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Unable to compute MemRefRegion for source operation\n.");
+ return false;
+ }
+ SmallVector<int64_t, 4> srcShape;
+ // Query 'srcWriteRegion' for 'srcShape' and 'srcNumElements'.
+ // by 'srcStoreOp' at depth 'dstLoopDepth'.
+ Optional<int64_t> srcNumElements =
+ srcWriteRegion.getConstantBoundingSizeAndShape(&srcShape);
+ if (!srcNumElements.hasValue())
+ return false;
+
+ // Compute MemRefRegion 'dstRegion' for 'dstStore/LoadOpInst' on 'memref'.
+ // TODO: Compute 'unionboundingbox' of all write regions (one for
+ // each store op in 'dstStoreOps').
+ SmallVector<Operation *, 2> dstStoreOps;
+ dstNode->getStoreOpsForMemref(memref, &dstStoreOps);
+ SmallVector<Operation *, 2> dstLoadOps;
+ dstNode->getLoadOpsForMemref(memref, &dstLoadOps);
+
+ auto *dstOpInst = dstStoreOps.empty() ? dstLoadOps[0] : dstStoreOps[0];
+ MemRefRegion dstRegion(dstOpInst->getLoc());
+ if (failed(dstRegion.compute(dstOpInst, /*loopDepth=*/0))) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Unable to compute MemRefRegion for dest operation\n.");
+ return false;
+ }
+ SmallVector<int64_t, 4> dstShape;
+ // Query 'dstRegion' for 'dstShape' and 'dstNumElements'.
+ // by 'dstOpInst' at depth 'dstLoopDepth'.
+ Optional<int64_t> dstNumElements =
+ dstRegion.getConstantBoundingSizeAndShape(&dstShape);
+ if (!dstNumElements.hasValue())
+ return false;
+
+ // Return false if write region is not a superset of 'srcNodes' write
+ // region to 'memref'.
+ // TODO: Check the shape and lower bounds here too.
+ if (srcNumElements != dstNumElements)
+ return false;
+
+ // Return false if 'memref' is used by a non-affine operation that is
+ // between node 'srcId' and node 'dstId'.
+ if (hasNonAffineUsersOnThePath(srcId, dstId, mdg))
+ return false;
+
+ return true;
+}
+
// Checks the profitability of fusing a backwards slice of the loop nest
// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
// The argument 'srcStoreOpInst' is used to calculate the storage reduction on
@@ -1027,6 +1029,9 @@ static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
// the largest computation slice at the maximal dst loop depth (closest to
// the load) to minimize reuse distance and potentially enable subsequent
// load/store forwarding.
+// NOTE: If the dst loop nest includes multiple loads in 'dstLoadOpInsts' for
+// the same memref as is written by 'srcOpInst', then the union of slice
+// loop bounds is used to compute the slice and associated slice cost.
// NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
// nest, at which the src computation slice is inserted/fused.
// NOTE: We attempt to maximize the dst loop depth, but there are cases
@@ -1036,18 +1041,18 @@ static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
// *) Compares the total cost of the unfused loop nests to the min cost fused
// loop nest computed in the previous step, and returns true if the latter
// is lower.
-// TODO: Extend profitability analysis to support scenarios with multiple
-// stores.
static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
- AffineForOp dstForOp,
+ ArrayRef<Operation *> dstLoadOpInsts,
ArrayRef<ComputationSliceState> depthSliceUnions,
unsigned maxLegalFusionDepth,
unsigned *dstLoopDepth,
double computeToleranceThreshold) {
LLVM_DEBUG({
llvm::dbgs() << "Checking whether fusion is profitable between src op:\n";
- llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n";
- llvm::dbgs() << dstForOp << "\n";
+ llvm::dbgs() << ' ' << *srcOpInst << " and destination op(s)\n";
+ for (auto dstOpInst : dstLoadOpInsts) {
+ llvm::dbgs() << " " << *dstOpInst << "\n";
+ };
});
if (maxLegalFusionDepth == 0) {
@@ -1065,8 +1070,11 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
return false;
// Compute cost of dst loop nest.
+ SmallVector<AffineForOp, 4> dstLoopIVs;
+ getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs);
+
LoopNestStats dstLoopNestStats;
- if (!getLoopNestStats(dstForOp, &dstLoopNestStats))
+ if (!getLoopNestStats(dstLoopIVs[0], &dstLoopNestStats))
return false;
// Search for min cost value for 'dstLoopDepth'. At each value of
@@ -1100,19 +1108,18 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
int64_t srcWriteRegionSizeBytes = maybeSrcWriteRegionSizeBytes.getValue();
// Compute op instance count for the src loop nest.
- uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
+ uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], dstLoopNestStats);
// Evaluate all depth choices for materializing the slice in the destination
// loop nest.
for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
- const ComputationSliceState &slice = depthSliceUnions[i - 1];
// Skip slice union if it wasn't computed for this depth.
- if (slice.isEmpty())
+ if (depthSliceUnions[i - 1].isEmpty())
continue;
int64_t fusedLoopNestComputeCost;
- if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp,
- dstLoopNestStats, slice,
+ if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstLoopIVs[0],
+ dstLoopNestStats, depthSliceUnions[i - 1],
&fusedLoopNestComputeCost)) {
LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n.");
continue;
@@ -1124,11 +1131,11 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
1;
// Determine what the slice write MemRefRegion would be, if the src loop
- // nest slice 'slice' were to be inserted into the dst loop nest at loop
- // depth 'i'.
+ // nest slice 'depthSliceUnions[i - 1]' were to be inserted into the dst
+ // loop nest at loop depth 'i'.
MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
- &slice))) {
+ &depthSliceUnions[i - 1]))) {
LLVM_DEBUG(llvm::dbgs()
<< "Failed to compute slice write region at loopDepth: " << i
<< "\n");
@@ -1211,7 +1218,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
<< "\n fused loop nest compute cost: "
<< minFusedLoopNestComputeCost << "\n");
- auto dstMemSize = getMemoryFootprintBytes(dstForOp);
+ auto dstMemSize = getMemoryFootprintBytes(dstLoopIVs[0]);
auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);
Optional<double> storageReduction = None;
@@ -1315,6 +1322,8 @@ struct GreedyFusion {
MemRefDependenceGraph *mdg;
// Worklist of graph nodes visited during the fusion pass.
SmallVector<unsigned, 8> worklist;
+ // Set of graph nodes which are present on the worklist.
+ llvm::SmallDenseSet<unsigned, 16> worklistSet;
// Parameter for local buffer size threshold.
unsigned localBufSizeThreshold;
// Parameter for fast memory space.
@@ -1335,14 +1344,16 @@ struct GreedyFusion {
fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
computeToleranceThreshold(computeToleranceThreshold) {}
- /// Initializes 'worklist' with nodes from 'mdg'.
+ // Initializes 'worklist' with nodes from 'mdg'
void init() {
// TODO: Add a priority queue for prioritizing nodes by different
// metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
worklist.clear();
+ worklistSet.clear();
for (auto &idAndNode : mdg->nodes) {
const Node &node = idAndNode.second;
worklist.push_back(node.id);
+ worklistSet.insert(node.id);
}
}
@@ -1361,11 +1372,11 @@ struct GreedyFusion {
}
void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
- LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
init();
while (!worklist.empty()) {
unsigned dstId = worklist.back();
worklist.pop_back();
+ worklistSet.erase(dstId);
// Skip if this node was removed (fused into another node).
if (mdg->nodes.count(dstId) == 0)
@@ -1375,97 +1386,114 @@ struct GreedyFusion {
// Skip if 'dstNode' is not a loop nest.
if (!isa<AffineForOp>(dstNode->op))
continue;
-
- LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
-
// Sink sequential loops in 'dstNode' (and thus raise parallel loops)
// while preserving relative order. This can increase the maximum loop
// depth at which we can fuse a slice of a producer loop nest into a
// consumer loop nest.
sinkSequentialLoops(dstNode);
- auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
-
- // Try to fuse 'dstNode' with candidate producer loops until a fixed point
- // is reached. Fusing two loops may expose new fusion opportunities.
- bool dstNodeChanged;
- do {
- // Gather src loop candidates for 'dstNode' and visit them in "quasi"
- // reverse program order to minimize the number of iterations needed to
- // reach the fixed point. Note that this is a best effort approach since
- // 'getProducerCandidates' does not always guarantee that program order
- // in 'srcIdCandidates'.
- dstNodeChanged = false;
- DenseSet<unsigned> srcIdCandidates;
- getProducerCandidates(dstId, mdg, srcIdCandidates);
-
- /// Visit candidates in reverse node id order. This order corresponds to
- /// the reverse program order when the 'mdg' is created. However,
- /// reverse program order is not guaranteed and must not be required.
- /// Reverse program order won't be held if the 'mdg' is reused from a
- /// previous fusion step or if the node creation order changes in the
- /// future to support more advance cases.
- SmallVector<unsigned, 16> sortedSrcIdCandidates;
- sortedSrcIdCandidates.reserve(srcIdCandidates.size());
- sortedSrcIdCandidates.append(srcIdCandidates.begin(),
- srcIdCandidates.end());
- llvm::sort(sortedSrcIdCandidates, std::greater<unsigned>());
-
- for (unsigned srcId : sortedSrcIdCandidates) {
+
+ SmallVector<Operation *, 4> loads = dstNode->loads;
+ SmallVector<Operation *, 4> dstLoadOpInsts;
+ DenseSet<Value> visitedMemrefs;
+ while (!loads.empty()) {
+ // Get memref of load on top of the stack.
+ auto memref = cast<AffineReadOpInterface>(loads.back()).getMemRef();
+ if (visitedMemrefs.count(memref) > 0)
+ continue;
+ visitedMemrefs.insert(memref);
+ // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'.
+ moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts);
+ // Skip if no input edges along which to fuse.
+ if (mdg->inEdges.count(dstId) == 0)
+ continue;
+ // Iterate through in-edges for 'dstId' and src node id for any
+ // edges on 'memref'.
+ SmallVector<unsigned, 2> srcNodeIds;
+ for (auto &srcEdge : mdg->inEdges[dstId]) {
+ // Skip 'srcEdge' if not for 'memref'.
+ if (srcEdge.value != memref)
+ continue;
+ srcNodeIds.push_back(srcEdge.id);
+ }
+ for (unsigned srcId : srcNodeIds) {
+ // Skip if this node was removed (fused into another node).
+ if (mdg->nodes.count(srcId) == 0)
+ continue;
// Get 'srcNode' from which to attempt fusion into 'dstNode'.
auto *srcNode = mdg->getNode(srcId);
- auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
- LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
- << " for dst loop " << dstId << "\n");
-
- DenseSet<Value> producerConsumerMemrefs;
- gatherProducerConsumerMemrefs(srcId, dstId, mdg,
- producerConsumerMemrefs);
-
- // Skip if 'srcNode' out edge count on any memref is greater than
- // 'maxSrcUserCount'.
- if (any_of(producerConsumerMemrefs, [&](Value memref) {
- return mdg->getOutEdgeCount(srcNode->id, memref) >
- maxSrcUserCount;
- }))
+ // Skip if 'srcNode' is not a loop nest.
+ if (!isa<AffineForOp>(srcNode->op))
continue;
+ // Skip if 'srcNode' has more than one live-out store to a
+ // function-local memref.
+ // TODO: Support more generic multi-output src loop nests
+ // fusion.
+ auto srcStoreOp = mdg->getUniqueOutgoingStore(srcNode);
+ if (!srcStoreOp) {
+ // Get the src store op at the deepest loop depth.
+ // We will use 'LoopFusionUtils::canFuseLoops' to check fusion
+ // feasibility for loops with multiple stores.
+ unsigned maxLoopDepth = 0;
+ for (auto *op : srcNode->stores) {
+ auto storeOp = cast<AffineWriteOpInterface>(op);
+ if (storeOp.getMemRef() != memref) {
+ srcStoreOp = nullptr;
+ break;
+ }
+ unsigned loopDepth = getNestingDepth(storeOp);
+ if (loopDepth > maxLoopDepth) {
+ maxLoopDepth = loopDepth;
+ srcStoreOp = storeOp;
+ }
+ }
+ if (!srcStoreOp)
+ continue;
+ }
- // Gather memrefs in 'srcNode' that are written and escape to the
- // function (e.g., memref function arguments, returned memrefs,
- // memrefs passed to function calls, etc.).
- DenseSet<Value> srcEscapingMemRefs;
- gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
-
- // Skip if there are non-affine operations in between the 'srcNode'
- // and 'dstNode' using their memrefs. If so, we wouldn't be able to
- // compute a legal insertion point for now. 'srcNode' and 'dstNode'
- // memrefs with non-affine operation users would be considered
- // escaping memrefs so we can limit this check to only scenarios with
- // escaping memrefs.
- if (!srcEscapingMemRefs.empty() &&
- hasNonAffineUsersOnThePath(srcId, dstId, mdg)) {
- LLVM_DEBUG(
- llvm::dbgs()
- << "Can't fuse: non-affine users in between the loops\n.");
+ // Unique outgoing store found must write to 'memref' since 'memref'
+ // is the one that established the producer-consumer relationship
+ // between 'srcNode' and 'dstNode'.
+ assert(srcStoreOp.getMemRef() == memref &&
+ "Found store to unexpected memref");
+
+ // Skip if 'srcNode' writes to any live in or escaping memrefs,
+ // and cannot be fused.
+ bool writesToLiveInOrOut =
+ mdg->writesToLiveInOrEscapingMemrefs(srcNode->id);
+ if (writesToLiveInOrOut &&
+ !canFuseSrcWhichWritesToLiveOut(srcId, dstId, srcStoreOp, mdg))
continue;
+
+ // Don't create a private memref if 'writesToLiveInOrOut'.
+ bool createPrivateMemref = !writesToLiveInOrOut;
+ // Don't create a private memref if 'srcNode' has in edges on
+ // 'memref', or if 'dstNode' has out edges on 'memref'.
+ if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) > 0 ||
+ mdg->getOutEdgeCount(dstNode->id, memref) > 0) {
+ createPrivateMemref = false;
}
+ // Skip if 'srcNode' out edge count on 'memref' > 'maxSrcUserCount'.
+ if (mdg->getOutEdgeCount(srcNode->id, memref) > maxSrcUserCount)
+ continue;
+
// Compute an operation list insertion point for the fused loop
// nest which preserves dependences.
- Operation *fusedLoopInsPoint =
+ Operation *insertPointInst =
mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
- if (fusedLoopInsPoint == nullptr)
+ if (insertPointInst == nullptr)
continue;
- // Compute the innermost common loop depth for dstNode
- // producer-consumer loads/stores.
+ auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
+ auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
+
+ // Compute the innermost common loop depth for dstNode loads/stores.
SmallVector<Operation *, 2> dstMemrefOps;
for (Operation *op : dstNode->loads)
- if (producerConsumerMemrefs.count(
- cast<AffineReadOpInterface>(op).getMemRef()) > 0)
+ if (cast<AffineReadOpInterface>(op).getMemRef() == memref)
dstMemrefOps.push_back(op);
for (Operation *op : dstNode->stores)
- if (producerConsumerMemrefs.count(
- cast<AffineWriteOpInterface>(op).getMemRef()))
+ if (cast<AffineWriteOpInterface>(op).getMemRef() == memref)
dstMemrefOps.push_back(op);
unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps);
@@ -1474,7 +1502,7 @@ struct GreedyFusion {
unsigned maxLegalFusionDepth = 0;
SmallVector<ComputationSliceState, 8> depthSliceUnions;
depthSliceUnions.resize(dstLoopDepthTest);
- FusionStrategy strategy(FusionStrategy::ProducerConsumer);
+ FusionStrategy strategy(FusionStrategy::ProducerConsumer, memref);
for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
FusionResult result = mlir::canFuseLoops(
srcAffineForOp, dstAffineForOp,
@@ -1484,82 +1512,27 @@ struct GreedyFusion {
maxLegalFusionDepth = i;
}
- if (maxLegalFusionDepth == 0) {
- LLVM_DEBUG(llvm::dbgs()
- << "Can't fuse: fusion is not legal at any depth\n");
+ // Skip if fusion is not feasible at any loop depths.
+ if (maxLegalFusionDepth == 0)
continue;
- }
// Check if fusion would be profitable. We skip profitability analysis
// for maximal fusion since we already know the maximal legal depth to
// fuse.
unsigned bestDstLoopDepth = maxLegalFusionDepth;
- if (!maximalFusion) {
- // Retrieve producer stores from the src loop.
- SmallVector<Operation *, 2> producerStores;
- for (Operation *op : srcNode->stores)
- if (producerConsumerMemrefs.count(
- cast<AffineWriteOpInterface>(op).getMemRef()))
- producerStores.push_back(op);
-
- // TODO: Suppport multiple producer stores in profitability
- // analysis. We limit profitability analysis to only scenarios with
- // a single producer store for now. Note that some multi-store
- // producer scenarios will still go through profitability analysis
- // if only one of the stores is involved the producer-consumer
- // relationship of the candidate loops.
- assert(producerStores.size() > 0 && "Expected producer store");
- if (producerStores.size() > 1)
- LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not "
- "supported for this case\n");
- else if (!isFusionProfitable(producerStores[0], producerStores[0],
- dstAffineForOp, depthSliceUnions,
- maxLegalFusionDepth, &bestDstLoopDepth,
- computeToleranceThreshold))
- continue;
- }
+ if (!maximalFusion &&
+ !isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts,
+ depthSliceUnions, maxLegalFusionDepth,
+ &bestDstLoopDepth, computeToleranceThreshold))
+ continue;
assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
- ComputationSliceState &bestSlice =
- depthSliceUnions[bestDstLoopDepth - 1];
- assert(!bestSlice.isEmpty() && "Missing slice union for depth");
-
- // Determine if 'srcId' can be removed after fusion, taking into
- // account remaining dependences, escaping memrefs and the fusion
- // insertion point.
- bool removeSrcNode = canRemoveSrcNodeAfterFusion(
- srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
- mdg);
-
- DenseSet<Value> privateMemrefs;
- for (Value memref : producerConsumerMemrefs) {
- // Don't create a private memref if 'srcNode' writes to escaping
- // memrefs.
- if (srcEscapingMemRefs.count(memref) > 0)
- continue;
-
- // Don't create a private memref if 'srcNode' has in edges on
- // 'memref' or 'dstNode' has out edges on 'memref'.
- if (mdg->getIncomingMemRefAccesses(srcId, memref) > 0 ||
- mdg->getOutEdgeCount(dstId, memref) > 0)
- continue;
-
- // If 'srcNode' will be removed but it has out edges on 'memref' to
- // nodes other than 'dstNode', we have to preserve dependences and
- // cannot create a private memref.
- if (removeSrcNode &&
- any_of(mdg->outEdges[srcId], [&](const auto &edge) {
- return edge.value == memref && edge.id != dstId;
- }))
- continue;
-
- // Create a private version of this memref.
- privateMemrefs.insert(memref);
- }
+ assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
+ "Missing slice union for depth");
// Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
- fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
- dstNodeChanged = true;
+ fuseLoops(srcAffineForOp, dstAffineForOp,
+ depthSliceUnions[bestDstLoopDepth - 1]);
LLVM_DEBUG(llvm::dbgs()
<< "Fused src loop " << srcId << " into dst loop " << dstId
@@ -1567,20 +1540,18 @@ struct GreedyFusion {
<< dstAffineForOp << "\n");
// Move 'dstAffineForOp' before 'insertPointInst' if needed.
- if (fusedLoopInsPoint != dstAffineForOp.getOperation())
- dstAffineForOp.getOperation()->moveBefore(fusedLoopInsPoint);
+ if (insertPointInst != dstAffineForOp.getOperation())
+ dstAffineForOp->moveBefore(insertPointInst);
// Update edges between 'srcNode' and 'dstNode'.
- mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
- removeSrcNode);
+ mdg->updateEdges(srcNode->id, dstNode->id, memref,
+ createPrivateMemref);
// Collect slice loop stats.
LoopNestStateCollector dstForCollector;
dstForCollector.collect(dstAffineForOp);
- for (Value memref : privateMemrefs) {
+ if (createPrivateMemref) {
// Create private memref for 'memref' in 'dstAffineForOp'.
- // TODO: remove storesForMemref and move the code below to the
- // loop-if.
SmallVector<Operation *, 4> storesForMemref;
for (auto *storeOpInst : dstForCollector.storeOpInsts) {
if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() ==
@@ -1592,6 +1563,7 @@ struct GreedyFusion {
auto newMemRef = createPrivateMemRef(
dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
fastMemorySpace, localBufSizeThreshold);
+ visitedMemrefs.insert(newMemRef);
// Create new node in dependence graph for 'newMemRef' alloc op.
unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
// Add edge from 'newMemRef' node to dstNode.
@@ -1602,21 +1574,58 @@ struct GreedyFusion {
LoopNestStateCollector dstLoopCollector;
dstLoopCollector.collect(dstAffineForOp.getOperation());
+ // Add new load ops to current Node load op list 'loads' to continue
+ // fusing based on new operands.
+ for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
+ // NOTE: Change 'loads' to a hash set in case efficiency is an
+ // issue. We still use a vector since it's expected to be small.
+ if (!llvm::is_contained(loads, loadOpInst))
+ loads.push_back(loadOpInst);
+ }
+ // Clear visited memrefs after fusion so that previously visited src
+ // nodes are considered for fusion again in the context of the new
+ // fused node.
+ // TODO: This shouldn't be necessary if we visited candidates in the
+ // dependence graph in post-order or once we fully support multi-store
+ // producers. Currently, in a multi-store producer scenario such as
+ // A->B, A->C, B->C, we fail to fuse A+B due to the multiple outgoing
+ // edges. However, after fusing B+C, A has a single outgoing edge and
+ // can be fused if we revisit it in the context of the new fused B+C
+ // node.
+ visitedMemrefs.clear();
+
// Clear and add back loads and stores.
mdg->clearNodeLoadAndStores(dstNode->id);
mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
dstLoopCollector.storeOpInsts);
-
- if (removeSrcNode) {
- LLVM_DEBUG(llvm::dbgs()
- << "Removing src loop " << srcId << " after fusion\n");
- // srcNode is no longer valid after it is removed from mdg.
- srcAffineForOp.erase();
- mdg->removeNode(srcId);
- srcNode = nullptr;
+ // Remove old src loop nest if it no longer has outgoing dependence
+ // edges, and if it does not write to a memref which escapes the
+ // function. If 'writesToLiveInOrOut' is true, then 'srcNode' has been
+ // fused into 'dstNode' and write region of 'dstNode' covers the write
+ // region of 'srcNode', and 'srcNode' has no other users so it is safe
+ // to remove.
+ if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) {
+ mdg->removeNode(srcNode->id);
+ srcNode->op->erase();
+ } else {
+ // Add remaining users of 'oldMemRef' back on the worklist (if not
+ // already there), as its replacement with a local/private memref
+ // has reduced dependences on 'oldMemRef' which may have created new
+ // fusion opportunities.
+ if (mdg->outEdges.count(srcNode->id) > 0) {
+ SmallVector<MemRefDependenceGraph::Edge, 2> oldOutEdges =
+ mdg->outEdges[srcNode->id];
+ for (auto &outEdge : oldOutEdges) {
+ if (outEdge.value == memref &&
+ worklistSet.count(outEdge.id) == 0) {
+ worklist.push_back(outEdge.id);
+ worklistSet.insert(outEdge.id);
+ }
+ }
+ }
}
}
- } while (dstNodeChanged);
+ }
}
}
@@ -1627,6 +1636,7 @@ struct GreedyFusion {
while (!worklist.empty()) {
unsigned dstId = worklist.back();
worklist.pop_back();
+ worklistSet.erase(dstId);
// Skip if this node was removed (fused into another node).
if (mdg->nodes.count(dstId) == 0)
@@ -1688,7 +1698,7 @@ struct GreedyFusion {
SmallVector<ComputationSliceState, 8> depthSliceUnions;
depthSliceUnions.resize(dstLoopDepthTest);
unsigned maxLegalFusionDepth = 0;
- FusionStrategy strategy(memref);
+ FusionStrategy strategy(FusionStrategy::Sibling, memref);
for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
FusionResult result = mlir::canFuseLoops(
sibAffineForOp, dstAffineForOp,
@@ -1702,10 +1712,10 @@ struct GreedyFusion {
if (maxLegalFusionDepth == 0)
continue;
- unsigned bestDstLoopDepth = maxLegalFusionDepth;
+ unsigned bestDstLoopDepth = dstLoopDepthTest;
if (!maximalFusion) {
// Check if fusion would be profitable.
- if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstAffineForOp,
+ if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts,
depthSliceUnions, maxLegalFusionDepth,
&bestDstLoopDepth, computeToleranceThreshold))
continue;
diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
index 9749a8de2351..9759300f2e42 100644
--- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
@@ -191,8 +191,11 @@ gatherLoadsAndStores(AffineForOp forOp,
/// 'srcForOp' into consumer loop 'dstForOp' without violating data dependences.
// TODO: Generalize this check for sibling and more generic fusion scenarios.
// TODO: Support forward slice fusion.
-static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
- ArrayRef<Operation *> dstOps) {
+static unsigned getMaxLoopDepth(ArrayRef<Operation *> dstOps,
+ FusionStrategy fusionStrategy) {
+ assert(fusionStrategy.strategy == FusionStrategy::ProducerConsumer &&
+ "Fusion strategy not supported");
+
if (dstOps.empty())
// Expected at least one memory operation.
// TODO: Revisit this case with a specific example.
@@ -200,14 +203,15 @@ static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
// Filter out ops in 'dstOps' that do not use the producer-consumer memref so
// that they are not considered for analysis.
- DenseSet<Value> producerConsumerMemrefs;
- gatherProducerConsumerMemrefs(srcOps, dstOps, producerConsumerMemrefs);
+ // TODO: Currently, we pass the producer-consumer memref through
+ // fusionStrategy. We will retrieve the memrefs from 'srcOps' once we
+ // generalize the algorithm.
SmallVector<Operation *, 4> targetDstOps;
for (Operation *dstOp : dstOps) {
auto loadOp = dyn_cast<AffineReadOpInterface>(dstOp);
Value memref = loadOp ? loadOp.getMemRef()
: cast<AffineWriteOpInterface>(dstOp).getMemRef();
- if (producerConsumerMemrefs.count(memref) > 0)
+ if (memref == fusionStrategy.memref)
targetDstOps.push_back(dstOp);
}
@@ -304,10 +308,10 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
// loop dependences.
// TODO: Enable this check for sibling and more generic loop fusion
// strategies.
- if (fusionStrategy.getStrategy() == FusionStrategy::ProducerConsumer) {
+ if (fusionStrategy.strategy == FusionStrategy::ProducerConsumer) {
// TODO: 'getMaxLoopDepth' does not support forward slice fusion.
assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion");
- if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) {
+ if (getMaxLoopDepth(opsB, fusionStrategy) < dstLoopDepth) {
LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n");
return FusionResult::FailFusionDependence;
}
@@ -320,7 +324,7 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
// Filter out ops in 'opsA' to compute the slice union based on the
// assumptions made by the fusion strategy.
SmallVector<Operation *, 4> strategyOpsA;
- switch (fusionStrategy.getStrategy()) {
+ switch (fusionStrategy.strategy) {
case FusionStrategy::Generic:
// Generic fusion. Take into account all the memory operations to compute
// the slice union.
@@ -328,9 +332,10 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
break;
case FusionStrategy::ProducerConsumer:
// Producer-consumer fusion (AffineLoopFusion pass) only takes into
- // account stores in 'srcForOp' to compute the slice union.
+ // account stores to 'memref' in 'srcForOp' to compute the slice union.
for (Operation *op : opsA) {
- if (isa<AffineWriteOpInterface>(op))
+ auto store = dyn_cast<AffineWriteOpInterface>(op);
+ if (store && store.getMemRef() == fusionStrategy.memref)
strategyOpsA.push_back(op);
}
break;
@@ -339,7 +344,7 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
// to 'memref' in 'srcForOp' to compute the slice union.
for (Operation *op : opsA) {
auto load = dyn_cast<AffineReadOpInterface>(op);
- if (load && load.getMemRef() == fusionStrategy.getSiblingFusionMemRef())
+ if (load && load.getMemRef() == fusionStrategy.memref)
strategyOpsA.push_back(op);
}
break;
@@ -623,23 +628,3 @@ bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
/*tripCountOverrideMap=*/nullptr, &computeCostMap);
return true;
}
-
-/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
-/// producer-consumer dependence between write ops in 'srcOps' and read ops in
-/// 'dstOps'.
-void mlir::gatherProducerConsumerMemrefs(
- ArrayRef<Operation *> srcOps, ArrayRef<Operation *> dstOps,
- DenseSet<Value> &producerConsumerMemrefs) {
- // Gather memrefs from stores in 'srcOps'.
- DenseSet<Value> srcStoreMemRefs;
- for (Operation *op : srcOps)
- if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op))
- srcStoreMemRefs.insert(storeOp.getMemRef());
-
- // Compute the intersection between memrefs from stores in 'srcOps' and
- // memrefs from loads in 'dstOps'.
- for (Operation *op : dstOps)
- if (auto loadOp = dyn_cast<AffineReadOpInterface>(op))
- if (srcStoreMemRefs.count(loadOp.getMemRef()) > 0)
- producerConsumerMemrefs.insert(loadOp.getMemRef());
-}
diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
index c1bccea4c9f5..a23f0e2ee430 100644
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -364,8 +364,8 @@ func @should_fuse_and_move_to_preserve_war_dep() {
// -----
-// CHECK-LABEL: func @should_fuse_if_top_level_access() {
-func @should_fuse_if_top_level_access() {
+// CHECK-LABEL: func @should_fuse_with_private_memref_if_top_level_access() {
+func @should_fuse_with_private_memref_if_top_level_access() {
%m = alloc() : memref<10xf32>
%cf7 = constant 7.0 : f32
@@ -378,45 +378,14 @@ func @should_fuse_if_top_level_access() {
%c0 = constant 4 : index
%v1 = affine.load %m[%c0] : memref<10xf32>
- // Top-level load to '%m' should prevent creating a private memref but
- // loop nests should be fused and '%i0' should be removed.
- // CHECK: %[[m:.*]] = alloc() : memref<10xf32>
- // CHECK-NOT: alloc
-
- // CHECK: affine.for %[[i1:.*]] = 0 to 10 {
- // CHECK-NEXT: affine.store %{{.*}}, %[[m]][%[[i1]]] : memref<10xf32>
- // CHECK-NEXT: affine.load %[[m]][%[[i1]]] : memref<10xf32>
- // CHECK-NEXT: }
- // CHECK: affine.load %[[m]][%{{.*}}] : memref<10xf32>
- return
-}
-
-// -----
-
-// CHECK-LABEL: func @should_fuse_but_not_remove_src() {
-func @should_fuse_but_not_remove_src() {
- %m = alloc() : memref<100xf32>
- %cf7 = constant 7.0 : f32
-
- affine.for %i0 = 0 to 100 {
- affine.store %cf7, %m[%i0] : memref<100xf32>
- }
- affine.for %i1 = 0 to 17 {
- %v0 = affine.load %m[%i1] : memref<100xf32>
- }
- %v1 = affine.load %m[99] : memref<100xf32>
-
- // Loop '%i0' and '%i1' should be fused but '%i0' shouldn't be removed to
- // preserve the dependence with the top-level access.
- // CHECK: affine.for %{{.*}} = 0 to 100 {
- // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<100xf32>
+ // Top-level load to '%{{.*}}' should prevent fusion.
+ // CHECK: affine.for %{{.*}} = 0 to 10 {
+ // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: }
- // CHECK-NEXT: affine.for %{{.*}} = 0 to 17 {
+ // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 {
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
// CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
// CHECK-NEXT: }
- // CHECK-NEXT: affine.load %{{.*}}[99] : memref<100xf32>
- // CHECK-NEXT: return
return
}
@@ -1141,8 +1110,8 @@ func @should_fuse_with_private_memrefs_with_diff_shapes() {
// -----
-// CHECK-LABEL: func @should_fuse_live_out_arg_but_preserve_src_loop(%{{.*}}: memref<10xf32>) {
-func @should_fuse_live_out_arg_but_preserve_src_loop(%arg0: memref<10xf32>) {
+// CHECK-LABEL: func @should_not_fuse_live_out_arg(%{{.*}}: memref<10xf32>) {
+func @should_not_fuse_live_out_arg(%arg0: memref<10xf32>) {
%cf7 = constant 7.0 : f32
affine.for %i0 = 0 to 10 {
@@ -1160,7 +1129,6 @@ func @should_fuse_live_out_arg_but_preserve_src_loop(%arg0: memref<10xf32>) {
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: affine.for %{{.*}} = 0 to 9 {
- // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return
@@ -1192,8 +1160,8 @@ func @should_fuse_live_out_arg(%arg0: memref<10xf32>) {
// -----
-// CHECK-LABEL: func @should_fuse_escaping_memref_but_preserve_src_loop() -> memref<10xf32>
-func @should_fuse_escaping_memref_but_preserve_src_loop() -> memref<10xf32> {
+// CHECK-LABEL: func @should_not_fuse_escaping_memref() -> memref<10xf32>
+func @should_not_fuse_escaping_memref() -> memref<10xf32> {
%cf7 = constant 7.0 : f32
%m = alloc() : memref<10xf32>
affine.for %i0 = 0 to 10 {
@@ -1202,21 +1170,19 @@ func @should_fuse_escaping_memref_but_preserve_src_loop() -> memref<10xf32> {
affine.for %i1 = 0 to 9 {
%v0 = affine.load %m[%i1] : memref<10xf32>
}
- // This tests that the loop nest '%i0' should not be removed after fusion
- // because it writes to memref '%m', which is returned by the function, and
- // the '%i1' memory region does not cover '%i0' memory region.
-
+ // This tests that the loop nest '%{{.*}}' should not be removed after fusion
+ // because it writes to memref '%{{.*}}' which is returned by the function.
// CHECK-DAG: alloc() : memref<10xf32>
// CHECK: affine.for %{{.*}} = 0 to 10 {
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: affine.for %{{.*}} = 0 to 9 {
- // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return %{{.*}} : memref<10xf32>
return %m : memref<10xf32>
}
+
// -----
// This should fuse with the %in becoming a 1x1x1.
@@ -1264,7 +1230,7 @@ func @R3_to_R2_reshape() {
// -----
-func @should_fuse_multi_output_producer() {
+func @should_not_fuse_multi_output_producer() {
%a = alloc() : memref<10xf32>
%b = alloc() : memref<10xf32>
@@ -1280,10 +1246,12 @@ func @should_fuse_multi_output_producer() {
}
// CHECK: affine.for %{{.*}} = 0 to 10 {
- // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
- // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
- // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
- // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
+ // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
+ // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 {
+ // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
+ // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return
return
@@ -1536,8 +1504,8 @@ func @should_fuse_at_depth_above_loop_carried_dependence(%arg0: memref<64x4xf32>
// -----
-// CHECK-LABEL: func @should_fuse_only_two_loops_and_remove_producer() {
-func @should_fuse_only_two_loops_and_remove_producer() {
+// CHECK-LABEL: func @should_fuse_after_private_memref_creation() {
+func @should_fuse_after_private_memref_creation() {
%a = alloc() : memref<10xf32>
%b = alloc() : memref<10xf32>
@@ -1557,21 +1525,18 @@ func @should_fuse_only_two_loops_and_remove_producer() {
// On the first visit to '%i2', the fusion algorithm can not fuse loop nest
// '%i0' into '%i2' because of the dependences '%i0' and '%i2' each have on
- // '%i1'. Then, '%i0' is fused into '%i1' and no private memref is created for
- // memref '%a' to be able to remove '%i0' and still preserve the depencence on
- // '%a' with '%i2'.
- // TODO: Alternatively, we could fuse '%i0' into '%i1' with a private memref,
- // the dependence between '%i0' and '%i1' on memref '%a' would no longer exist,
- // and '%i0' could be fused into '%i2' as well. Note that this approach would
- // duplicate the computation in loop nest '%i0' to loop nests '%i1' and '%i2',
- // which would limit its profitability.
+ // '%i1'. However, once the loop nest '%i0' is fused into '%i1' with a
+ // private memref, the dependence between '%i0' and '%i1' on memref '%a' no
+ // longer exists, so '%i0' can now be fused into '%i2'.
+
// CHECK: affine.for %{{.*}} = 0 to 10 {
- // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
- // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
+ // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
+ // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: affine.for %{{.*}} = 0 to 10 {
- // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
+ // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
+ // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return
@@ -2255,7 +2220,7 @@ func @affine_2_dependent_mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<10
}
}
- // CHECK: affine.for %{{.*}} = 0 to 1024 {
+ // CHECK: affine.for %{{.*}} = 0 to 1024 {
// CHECK-NEXT: affine.for %{{.*}} = 0 to 1024 {
// CHECK-NEXT: affine.for %{{.*}} = 0 to 1024 {
// CHECK-NEXT: affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
@@ -2346,8 +2311,8 @@ func @should_fuse_function_live_out_multi_store_producer(%live_in_out_m : memref
}
// CHECK: affine.for %[[i0:.*]] = 0 to 10 {
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%[[i0]]] : memref<10xf32>
- // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
- // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
+ // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%[[i0]]] : memref<10xf32>
+ // CHECK-NEXT: affine.load %{{.*}}[%[[i0]]] : memref<10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return
return
@@ -2408,11 +2373,12 @@ func @mul_add_0(%arg0: memref<3x4xf32>, %arg1: memref<4x3xf32>, %arg2: memref<3x
// -----
-// Verify that 'fuseProducerConsumerNodes' fuse a producer loop with a store
-// that has multiple outgoing edges.
+// Verify that 'fuseProducerConsumerNodes' doesn't fuse a producer loop with
+// a store that has multiple outgoing edges. Sibling loop fusion should not fuse
+// any of these loops due to dependencies on external memref '%a'.
-// CHECK-LABEL: func @should_fuse_multi_outgoing_edge_store_producer
-func @should_fuse_multi_outgoing_edge_store_producer(%a : memref<1xf32>) {
+// CHECK-LABEL: func @should_not_fuse_multi_outgoing_edge_store_producer1
+func @should_not_fuse_multi_outgoing_edge_store_producer1(%a : memref<1xf32>) {
%cst = constant 0.000000e+00 : f32
affine.for %arg0 = 0 to 1 {
affine.store %cst, %a[%arg0] : memref<1xf32>
@@ -2425,12 +2391,9 @@ func @should_fuse_multi_outgoing_edge_store_producer(%a : memref<1xf32>) {
affine.for %arg0 = 0 to 1 {
%0 = affine.load %a[%arg0] : memref<1xf32>
}
- // CHECK: affine.for %{{.*}} = 0 to 1 {
- // CHECK-NEXT: affine.store
- // CHECK-NEXT: affine.load
- // CHECK-NEXT: affine.load
- // CHECK-NEXT: }
-
+ // CHECK: affine.for %{{.*}} = 0 to 1
+ // CHECK: affine.for %{{.*}} = 0 to 1
+ // CHECK: affine.for %{{.*}} = 0 to 1
return
}
@@ -2700,109 +2663,3 @@ func @fuse_minor_affine_map(%in: memref<128xf32>, %out: memref<20x512xf32>) {
// MAXIMAL: affine.for
// MAXIMAL-NEXT: affine.for
// MAXIMAL-NOT: affine.for
-// MAXIMAL: return
-
-// -----
-
-// CHECK-LABEL: func @should_fuse_multi_store_producer_and_privatize_memfefs
-func @should_fuse_multi_store_producer_and_privatize_memfefs() {
- %a = alloc() : memref<10xf32>
- %b = alloc() : memref<10xf32>
- %c = alloc() : memref<10xf32>
- %cst = constant 0.000000e+00 : f32
- affine.for %arg0 = 0 to 10 {
- affine.store %cst, %a[%arg0] : memref<10xf32>
- affine.store %cst, %b[%arg0] : memref<10xf32>
- affine.store %cst, %c[%arg0] : memref<10xf32>
- %0 = affine.load %c[%arg0] : memref<10xf32>
- }
-
- affine.for %arg0 = 0 to 10 {
- %0 = affine.load %a[%arg0] : memref<10xf32>
- }
-
- affine.for %arg0 = 0 to 10 {
- %0 = affine.load %b[%arg0] : memref<10xf32>
- }
-
- // All the memrefs should be privatized except '%c', which is not involved in
- // the producer-consumer fusion.
- // CHECK: affine.for %{{.*}} = 0 to 10 {
- // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
- // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
- // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
- // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
- // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
- // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
- // CHECK-NEXT: }
-
- return
-}
-
-// -----
-
-func @should_fuse_multi_store_producer_with_scaping_memrefs_and_remove_src(
- %a : memref<10xf32>, %b : memref<10xf32>) {
- %cst = constant 0.000000e+00 : f32
- affine.for %i0 = 0 to 10 {
- affine.store %cst, %a[%i0] : memref<10xf32>
- affine.store %cst, %b[%i0] : memref<10xf32>
- }
-
- affine.for %i1 = 0 to 10 {
- %0 = affine.load %a[%i1] : memref<10xf32>
- }
-
- affine.for %i2 = 0 to 10 {
- %0 = affine.load %b[%i2] : memref<10xf32>
- }
-
- // Producer loop '%i0' should be removed after fusion since fusion is maximal.
- // No memref should be privatized since they escape the function.
- // CHECK: affine.for %{{.*}} = 0 to 10 {
- // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
- // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
- // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
- // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
- // CHECK-NEXT: }
- // CHECK-NOT: affine.for
-
- return
-}
-
-// -----
-
-func @should_fuse_multi_store_producer_with_scaping_memrefs_and_preserve_src(
- %a : memref<10xf32>, %b : memref<10xf32>) {
- %cst = constant 0.000000e+00 : f32
- affine.for %i0 = 0 to 10 {
- affine.store %cst, %a[%i0] : memref<10xf32>
- affine.store %cst, %b[%i0] : memref<10xf32>
- }
-
- affine.for %i1 = 0 to 5 {
- %0 = affine.load %a[%i1] : memref<10xf32>
- }
-
- affine.for %i2 = 0 to 10 {
- %0 = affine.load %b[%i2] : memref<10xf32>
- }
-
- // Loops '%i0' and '%i2' should be fused first and '%i0' should be removed
- // since fusion is maximal. Then the fused loop and '%i1' should be fused
- // and the fused loop shouldn't be removed since fusion is not maximal.
- // CHECK: affine.for %{{.*}} = 0 to 10 {
- // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
- // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
- // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
- // CHECK-NEXT: }
- // CHECK: affine.for %{{.*}} = 0 to 5 {
- // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
- // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
- // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
- // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
- // CHECK-NEXT: }
- // CHECK-NOT: affine.for
-
- return
-}
More information about the Mlir-commits
mailing list