[Mlir-commits] [mlir] ada5808 - [mlir] Enable cleanup of single iteration reduction loops being sibling-fused maximally

Sumesh Udayakumaran llvmlistbot at llvm.org
Thu Jul 15 14:07:46 PDT 2021


Author: Sumesh Udayakumaran
Date: 2021-07-16T00:07:20+03:00
New Revision: ada580863f8941f8b0426be0d78249f4cfa8f4d5

URL: https://github.com/llvm/llvm-project/commit/ada580863f8941f8b0426be0d78249f4cfa8f4d5
DIFF: https://github.com/llvm/llvm-project/commit/ada580863f8941f8b0426be0d78249f4cfa8f4d5.diff

LOG: [mlir] Enable cleanup of single iteration reduction loops being sibling-fused maximally

Changes include the following:
    1. Single-iteration reduction loops that are sibling-fused at the innermost
     insertion level are no longer treated as sequential loops; otherwise, the
    slice bounds of these loops would be reset.

    2. Promote the loops skipped in the previous step into their outer loops.

    3. Two utility functions, buildSliceTripCountMap and getSliceIterationCount, are moved from
mlir/lib/Transforms/Utils/LoopFusionUtils.cpp to mlir/lib/Analysis/Utils.cpp

Reviewed By: bondhugula, vinayaka-polymage

Differential Revision: https://reviews.llvm.org/D104249

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/Utils.h
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
    mlir/include/mlir/Transforms/LoopFusionUtils.h
    mlir/lib/Analysis/Utils.cpp
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Transforms/LoopFusion.cpp
    mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
    mlir/test/Transforms/loop-fusion.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h
index 06630f401f5e1..89b73dace9b7d 100644
--- a/mlir/include/mlir/Analysis/Utils.h
+++ b/mlir/include/mlir/Analysis/Utils.h
@@ -49,6 +49,9 @@ void getEnclosingAffineForAndIfOps(Operation &op,
 /// surrounding this operation.
 unsigned getNestingDepth(Operation *op);
 
+/// Returns true if `forOp` is parallel and has at least one reduction.
+bool isLoopParallelAndContainsReduction(AffineForOp forOp);
+
 /// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
 /// at 'forOp'.
 void getSequentialLoops(AffineForOp forOp,
@@ -184,6 +187,18 @@ void getComputationSliceState(Operation *depSourceOp, Operation *depSinkOp,
                               unsigned loopDepth, bool isBackwardSlice,
                               ComputationSliceState *sliceState);
 
+/// Return the product of the trip counts in `sliceTripCountMap`.
+uint64_t getSliceIterationCount(
+    const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap);
+
+/// Builds a map 'tripCountMap' from each AffineForOp in the slice loop nest
+/// of 'slice' to its constant trip count, taken from the slice bounds where
+/// present. Returns true on success, false otherwise (if a non-constant trip
+/// count was encountered).
+bool buildSliceTripCountMap(
+    const ComputationSliceState &slice,
+    llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap);
+
 /// Computes in 'sliceUnion' the union of all slice bounds computed at
 /// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB', and
 /// then verifies if it is valid. The parameter 'numCommonLoops' is the number

diff  --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index 1199d1bb272e2..a7c399cb0fbeb 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -406,6 +406,20 @@ void buildAffineLoopNest(OpBuilder &builder, Location loc, ValueRange lbs,
                          function_ref<void(OpBuilder &, Location, ValueRange)>
                              bodyBuilderFn = nullptr);
 
+/// Replace `loop` with a new loop whose iter operands are extended with
+/// `newIterOperands`, whose yield is extended with `newYieldedValues`, and
+/// whose body gains a block argument of the type of each `newIterArgs` value.
+/// The body of `loop` is moved into the new loop; `loop` itself is NOT erased
+/// and must be erased by the caller. If `replaceLoopResults` is true, all uses
+/// of `loop.getResults()` are replaced with the first `loop.getNumResults()`
+/// results of the new loop. The new loop is returned.
+/// Prerequisite: `newIterOperands.size() == newYieldedValues.size()`.
+AffineForOp replaceForOpWithNewYields(OpBuilder &b, AffineForOp loop,
+                                      ValueRange newIterOperands,
+                                      ValueRange newYieldedValues,
+                                      ValueRange newIterArgs,
+                                      bool replaceLoopResults = true);
+
 /// AffineBound represents a lower or upper bound in the for operation.
 /// This class does not own the underlying operands. Instead, it refers
 /// to the operands stored in the AffineForOp. Its life span should not exceed

diff  --git a/mlir/include/mlir/Transforms/LoopFusionUtils.h b/mlir/include/mlir/Transforms/LoopFusionUtils.h
index b66d38fae3b84..87a71b6041dbd 100644
--- a/mlir/include/mlir/Transforms/LoopFusionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopFusionUtils.h
@@ -114,10 +114,13 @@ canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth,
              ComputationSliceState *srcSlice,
              FusionStrategy fusionStrategy = FusionStrategy::Generic);
 
-/// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point
-/// and source slice loop bounds specified in 'srcSlice'.
+/// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion
+/// point and source slice loop bounds specified in 'srcSlice'.
+/// If `isInnermostSiblingInsertion` is true, single-iteration reduction loops
+/// of the fused slice are promoted into their parent loop.
 void fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
-               const ComputationSliceState &srcSlice);
+               const ComputationSliceState &srcSlice,
+               bool isInnermostSiblingInsertion = false);
 
 /// LoopNestStats aggregates various per-loop statistics (eg. loop trip count
 /// and operation count) for a loop nest up until (and including) the innermost

diff  --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index 221c79343ded4..262a329ab0de3 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Analysis/Utils.h"
 #include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/LoopAnalysis.h"
 #include "mlir/Analysis/PresburgerSet.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
@@ -969,6 +970,73 @@ mlir::computeSliceUnion(ArrayRef<Operation *> opsA, ArrayRef<Operation *> opsB,
   return SliceComputationResult::Success;
 }
 
+// Returns the constant value of `ub - lb`, or None if it is not constant.
+// TODO: extend this to handle multiple result maps.
+static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
+  assert(lbMap.getNumResults() == 1 && "expected single result bound map");
+  assert(ubMap.getNumResults() == 1 && "expected single result bound map");
+  assert(lbMap.getNumDims() == ubMap.getNumDims());
+  assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
+  AffineExpr lbExpr(lbMap.getResult(0));
+  AffineExpr ubExpr(ubMap.getResult(0));
+  auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
+                                         lbMap.getNumSymbols());
+  auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
+  if (!cExpr)
+    return None;
+  return cExpr.getValue();
+}
+
+// Builds a map 'tripCountMap' from each AffineForOp in the slice loop nest of
+// 'slice' to its constant trip count. Returns true on success, false otherwise
+// (if a non-constant trip count was encountered).
+// TODO: Make this work with non-unit step loops.
+bool mlir::buildSliceTripCountMap(
+    const ComputationSliceState &slice,
+    llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
+  unsigned numSrcLoopIVs = slice.ivs.size();
+  // Populate map from AffineForOp -> trip count
+  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
+    AffineForOp forOp = getForInductionVarOwner(slice.ivs[i]);
+    auto *op = forOp.getOperation();
+    AffineMap lbMap = slice.lbs[i];
+    AffineMap ubMap = slice.ubs[i];
+    // If lower or upper bound maps are null or provide no results, it implies
+    // that source loop was not at all sliced, and the entire loop will be a
+    // part of the slice.
+    if (!lbMap || lbMap.getNumResults() == 0 || !ubMap ||
+        ubMap.getNumResults() == 0) {
+      // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
+      if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
+        (*tripCountMap)[op] =
+            forOp.getConstantUpperBound() - forOp.getConstantLowerBound();
+        continue;
+      }
+      Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
+      if (maybeConstTripCount.hasValue()) {
+        (*tripCountMap)[op] = maybeConstTripCount.getValue();
+        continue;
+      }
+      return false;
+    }
+    Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
+    // Slice bounds are created with a constant ub - lb difference.
+    if (!tripCount.hasValue())
+      return false;
+    (*tripCountMap)[op] = tripCount.getValue();
+  }
+  return true;
+}
+
+// Return the number of iterations in the given slice.
+uint64_t mlir::getSliceIterationCount(
+    const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
+  uint64_t iterCount = 1;
+  for (const auto &count : sliceTripCountMap)
+    iterCount *= count.second;
+  return iterCount;
+}
+
 const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier";
 // Computes slice bounds by projecting out any loop IVs from
 // 'dependenceConstraints' at depth greater than 'loopDepth', and computes slice
@@ -1039,18 +1107,36 @@ void mlir::getComputationSliceState(
     getSequentialLoops(isBackwardSlice ? srcLoopIVs[0] : dstLoopIVs[0],
                        &sequentialLoops);
   }
-  // Clear all sliced loop bounds beginning at the first sequential loop, or
-  // first loop with a slice fusion barrier attribute..
-  // TODO: Use MemRef read/write regions instead of
-  // using 'kSliceFusionBarrierAttrName'.
   auto getSliceLoop = [&](unsigned i) {
     return isBackwardSlice ? srcLoopIVs[i] : dstLoopIVs[i];
   };
+  auto isInnermostInsertion = [&]() {
+    return (isBackwardSlice ? loopDepth >= srcLoopIVs.size()
+                            : loopDepth >= dstLoopIVs.size());
+  };
+  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
+  auto srcIsUnitSlice = [&]() {
+    return (buildSliceTripCountMap(*sliceState, &sliceTripCountMap) &&
+            (getSliceIterationCount(sliceTripCountMap) == 1));
+  };
+  // Clear all sliced loop bounds beginning at the first sequential loop, or
+  // first loop with a slice fusion barrier attribute..
+
   for (unsigned i = 0; i < numSliceLoopIVs; ++i) {
     Value iv = getSliceLoop(i).getInductionVar();
     if (sequentialLoops.count(iv) == 0 &&
         getSliceLoop(i)->getAttr(kSliceFusionBarrierAttrName) == nullptr)
       continue;
+    // Skip reset of bounds of reduction loop inserted in the destination loop
+    // that meets the following conditions:
+    //    1. Slice is  single trip count.
+    //    2. Loop bounds of the source and destination match.
+    //    3. Is being inserted at the innermost insertion point.
+    Optional<bool> isMaximal = sliceState->isMaximal();
+    if (isLoopParallelAndContainsReduction(getSliceLoop(i)) &&
+        isInnermostInsertion() && srcIsUnitSlice() && isMaximal.hasValue() &&
+        isMaximal.getValue())
+      continue;
     for (unsigned j = i; j < numSliceLoopIVs; ++j) {
       sliceState->lbs[j] = AffineMap();
       sliceState->ubs[j] = AffineMap();
@@ -1258,6 +1344,14 @@ Optional<int64_t> mlir::getMemoryFootprintBytes(AffineForOp forOp,
       std::next(Block::iterator(forInst)), memorySpace);
 }
 
+/// Returns true when `forOp` is parallel and has at least one reduction.
+bool mlir::isLoopParallelAndContainsReduction(AffineForOp forOp) {
+  // Reductions recognized by `isLoopParallel` are collected here.
+  SmallVector<LoopReduction> reductions;
+  bool isParallel = isLoopParallel(forOp, &reductions);
+  return isParallel && !reductions.empty();
+}
+
 /// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
 /// at 'forOp'.
 void mlir::getSequentialLoops(AffineForOp forOp,

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 1d496574639c8..52200f44c02a9 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1883,6 +1883,49 @@ void mlir::buildAffineLoopNest(
                           buildAffineLoopFromValues);
 }
 
+AffineForOp mlir::replaceForOpWithNewYields(OpBuilder &b, AffineForOp loop,
+                                            ValueRange newIterOperands,
+                                            ValueRange newYieldedValues,
+                                            ValueRange newIterArgs,
+                                            bool replaceLoopResults) {
+  assert(newIterOperands.size() == newYieldedValues.size() &&
+         "newIterOperands must be of the same size as newYieldedValues");
+  // Create a new loop before the existing one, with the extra operands.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(loop);
+  auto operands = llvm::to_vector<4>(loop.getIterOperands());
+  operands.append(newIterOperands.begin(), newIterOperands.end());
+  SmallVector<Value, 4> lbOperands(loop.getLowerBoundOperands());
+  SmallVector<Value, 4> ubOperands(loop.getUpperBoundOperands());
+  // The new loop reuses the bounds and the step of the original loop.
+  auto lbMap = loop.getLowerBoundMap();
+  auto ubMap = loop.getUpperBoundMap();
+  AffineForOp newLoop =
+      b.create<AffineForOp>(loop.getLoc(), lbOperands, lbMap, ubOperands, ubMap,
+                            loop.getStep(), operands);
+  // Move the original loop's body into the new loop; add the new iter args.
+  newLoop.getLoopBody().takeBody(loop.getLoopBody());
+  for (Value val : newIterArgs)
+    newLoop.getLoopBody().addArgument(val.getType());
+
+  // Update yield operation with new values to be added.
+  if (!newYieldedValues.empty()) {
+    auto yield = cast<AffineYieldOp>(newLoop.getBody()->getTerminator());
+    b.setInsertionPoint(yield);
+    auto yieldOperands = llvm::to_vector<4>(yield.getOperands());
+    yieldOperands.append(newYieldedValues.begin(), newYieldedValues.end());
+    b.create<AffineYieldOp>(yield.getLoc(), yieldOperands);
+    yield.erase();
+  }
+  // `loop` is left in place with an empty region; the caller must erase it.
+  if (replaceLoopResults) {
+    for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
+                                                    loop.getNumResults())))
+      std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
+  }
+  return newLoop;
+}
+
 //===----------------------------------------------------------------------===//
 // AffineIfOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 667a175cd266b..49bd52d6672d6 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -1682,6 +1682,7 @@ struct GreedyFusion {
   // Visits each node in the graph, and for each node, attempts to fuse it with
   // its sibling nodes (nodes which share a parent, but no dependence edges).
   void fuseSiblingNodes() {
+    LLVM_DEBUG(llvm::dbgs() << "--- Sibling Fusion ---\n");
     init();
     while (!worklist.empty()) {
       unsigned dstId = worklist.back();
@@ -1773,10 +1774,14 @@ struct GreedyFusion {
       assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
       assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
              "Fusion depth has no computed slice union");
-
+      // Check if source loop is being inserted in the innermost
+      // destination loop. Based on this, the fused loop may be optimized
+      // further inside `fuseLoops`.
+      bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
       // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
       mlir::fuseLoops(sibAffineForOp, dstAffineForOp,
-                      depthSliceUnions[bestDstLoopDepth - 1]);
+                      depthSliceUnions[bestDstLoopDepth - 1],
+                      isInnermostInsertion);
 
       auto dstForInst = cast<AffineForOp>(dstNode->op);
       // Update operation position of fused loop nest (if needed).

diff  --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
index 511f6a572f05e..29ee0cffe73d2 100644
--- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Analysis/AffineAnalysis.h"
 #include "mlir/Analysis/AffineStructures.h"
 #include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/IR/AffineExpr.h"
@@ -363,10 +364,74 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
   return FusionResult::Success;
 }
 
+/// Patch the body of `forOp`, a single-iteration reduction loop, into its
+/// parent affine.for, which is rebuilt with `forOp`'s iter operands and
+/// yielded values appended. If `siblingFusionUser` is set, ops using
+/// `forOp`'s results are moved after the rebuilt parent loop. Fails without
+/// modifying the IR unless `forOp` has a constant trip count of 1 and is
+/// directly nested in an affine.for.
+static LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp,
+                                                    bool siblingFusionUser) {
+  // Check if the reduction loop is a single iteration loop.
+  Optional<uint64_t> tripCount = getConstantTripCount(forOp);
+  if (!tripCount || tripCount.getValue() != 1)
+    return failure();
+  auto parentForOp = dyn_cast<AffineForOp>(forOp->getParentOp());
+  if (!parentForOp)
+    return failure();
+  unsigned numParentResults = parentForOp.getNumResults();
+  // Snapshot the yielded values before the parent loop is rebuilt.
+  auto newOperands =
+      llvm::to_vector<4>(forOp.getBody()->getTerminator()->getOperands());
+  OpBuilder b(parentForOp.getOperation());
+  // Replace the parent loop and add iter operands and results from `forOp`.
+  AffineForOp newLoop =
+      replaceForOpWithNewYields(b, parentForOp, forOp.getIterOperands(),
+                                newOperands, forOp.getRegionIterArgs());
+
+  // For sibling-fusion users, collect the ops that use `forOp`'s results;
+  // they are moved after the new parent loop once the results are replaced.
+  SetVector<Operation *> forwardSlice;
+  if (siblingFusionUser) {
+    for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
+      SetVector<Operation *> tmpForwardSlice;
+      getForwardSlice(forOp.getResult(i), &tmpForwardSlice);
+      forwardSlice.set_union(tmpForwardSlice);
+    }
+  }
+  // `forOp`'s results are now carried by the trailing results of `newLoop`.
+  for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i)
+    forOp.getResult(i).replaceAllUsesWith(
+        newLoop.getResult(i + numParentResults));
+  // Move the users of those results after the new parent loop.
+  if (siblingFusionUser) {
+    topologicalSort(forwardSlice);
+    for (Operation *op : llvm::reverse(forwardSlice))
+      op->moveAfter(newLoop);
+  }
+  // NOTE(review): assumes `forOp`'s single iteration runs at the parent's IV
+  // value, as for maximal sibling-fusion slices; other callers must check.
+  forOp.getInductionVar().replaceAllUsesWith(newLoop.getInductionVar());
+  // Replace the iter args with the trailing iter args of `newLoop`.
+  auto forOpIterArgs = forOp.getRegionIterArgs();
+  for (auto it : llvm::zip(forOpIterArgs, newLoop.getRegionIterArgs().take_back(
+                                              forOpIterArgs.size())))
+    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
+  // Move the loop body operations, except for its terminator, to the loop's
+  // containing block, then erase `forOp` and the now-empty old parent loop.
+  forOp.getBody()->back().erase();
+  forOp->getBlock()->getOperations().splice(Block::iterator(forOp),
+                                            forOp.getBody()->getOperations());
+  forOp.erase();
+  parentForOp.erase();
+  return success();
+}
+
 /// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point
 /// and source slice loop bounds specified in 'srcSlice'.
 void mlir::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
-                     const ComputationSliceState &srcSlice) {
+                     const ComputationSliceState &srcSlice,
+                     bool isInnermostSiblingInsertion) {
   // Clone 'srcForOp' into 'dstForOp' at 'srcSlice->insertPoint'.
   OpBuilder b(srcSlice.insertPoint->getBlock(), srcSlice.insertPoint);
   BlockAndValueMapping mapper;
@@ -392,9 +457,22 @@ void mlir::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
     }
   }
 
-  // Promote any single iteration slice loops.
-  for (AffineForOp forOp : sliceLoops)
-    (void)promoteIfSingleIteration(forOp);
+  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
+  auto srcIsUnitSlice = [&]() {
+    return (buildSliceTripCountMap(srcSlice, &sliceTripCountMap) &&
+            (getSliceIterationCount(sliceTripCountMap) == 1));
+  };
+  // Fix up and if possible, eliminate single iteration loops.
+  for (AffineForOp forOp : sliceLoops) {
+    if (isLoopParallelAndContainsReduction(forOp) &&
+        isInnermostSiblingInsertion && srcIsUnitSlice())
+      // Patch reduction loop - only ones that are sibling-fused with the
+      // destination loop - into the parent loop.
+      (void)promoteSingleIterReductionLoop(forOp, true);
+    else
+      // Promote any single iteration slice loops.
+      (void)promoteIfSingleIteration(forOp);
+  }
 }
 
 /// Collect loop nest statistics (eg. loop trip count and operation count)
@@ -484,74 +562,6 @@ static int64_t getComputeCostHelper(
   return tripCount * opCount;
 }
 
-// TODO: extend this to handle multiple result maps.
-static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
-  assert(lbMap.getNumResults() == 1 && "expected single result bound map");
-  assert(ubMap.getNumResults() == 1 && "expected single result bound map");
-  assert(lbMap.getNumDims() == ubMap.getNumDims());
-  assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
-  AffineExpr lbExpr(lbMap.getResult(0));
-  AffineExpr ubExpr(ubMap.getResult(0));
-  auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
-                                         lbMap.getNumSymbols());
-  auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
-  if (!cExpr)
-    return None;
-  return cExpr.getValue();
-}
-
-// Return the number of iterations in the given slice.
-static uint64_t getSliceIterationCount(
-    const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
-  uint64_t iterCount = 1;
-  for (const auto &count : sliceTripCountMap) {
-    iterCount *= count.second;
-  }
-  return iterCount;
-}
-
-// Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop
-// nest surrounding represented by slice loop bounds in 'slice'.
-// Returns true on success, false otherwise (if a non-constant trip count
-// was encountered).
-// TODO: Make this work with non-unit step loops.
-static bool buildSliceTripCountMap(
-    const ComputationSliceState &slice,
-    llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
-  unsigned numSrcLoopIVs = slice.ivs.size();
-  // Populate map from AffineForOp -> trip count
-  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
-    AffineForOp forOp = getForInductionVarOwner(slice.ivs[i]);
-    auto *op = forOp.getOperation();
-    AffineMap lbMap = slice.lbs[i];
-    AffineMap ubMap = slice.ubs[i];
-    // If lower or upper bound maps are null or provide no results, it implies
-    // that source loop was not at all sliced, and the entire loop will be a
-    // part of the slice.
-    if (!lbMap || lbMap.getNumResults() == 0 || !ubMap ||
-        ubMap.getNumResults() == 0) {
-      // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
-      if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
-        (*tripCountMap)[op] =
-            forOp.getConstantUpperBound() - forOp.getConstantLowerBound();
-        continue;
-      }
-      Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
-      if (maybeConstTripCount.hasValue()) {
-        (*tripCountMap)[op] = maybeConstTripCount.getValue();
-        continue;
-      }
-      return false;
-    }
-    Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
-    // Slice bounds are created with a constant ub - lb difference.
-    if (!tripCount.hasValue())
-      return false;
-    (*tripCountMap)[op] = tripCount.getValue();
-  }
-  return true;
-}
-
 /// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
 /// Currently, the total cost is computed by counting the total operation
 /// instance count (i.e. total number of operations in the loop body * loop

diff  --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
index 14a2bf0223e27..650a8adcdddaf 100644
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -3150,6 +3150,143 @@ func @no_fusion_cannot_compute_valid_slice() {
 
 // -----
 
+// MAXIMAL-LABEL:   func @reduce_add_f32_f32(
+func @reduce_add_f32_f32(%arg0: memref<64x64xf32, 1>, %arg1: memref<1x64xf32, 1>, %arg2: memref<1x64xf32, 1>) {
+  %cst_0 = constant 0.000000e+00 : f32
+  %cst_1 = constant 1.000000e+00 : f32
+  %0 = memref.alloca() : memref<f32, 1>
+  %1 = memref.alloca() : memref<f32, 1>
+  affine.for %arg3 = 0 to 1 {
+    affine.for %arg4 = 0 to 64 {
+      %accum = affine.for %arg5 = 0 to 64 iter_args (%prevAccum = %cst_0) -> f32 {
+        %4 = affine.load %arg0[%arg5, %arg4] : memref<64x64xf32, 1>
+        %5 = addf %prevAccum, %4 : f32
+        affine.yield %5 : f32
+      }
+      %accum_dbl = addf %accum, %accum : f32
+      affine.store %accum_dbl, %arg1[%arg3, %arg4] : memref<1x64xf32, 1>
+    }
+  }
+  affine.for %arg3 = 0 to 1 {
+    affine.for %arg4 = 0 to 64 {
+      %accum = affine.for %arg5 = 0 to 64 iter_args (%prevAccum = %cst_1) -> f32 {
+        %4 = affine.load %arg0[%arg5, %arg4] : memref<64x64xf32, 1>
+        %5 = mulf %prevAccum, %4 : f32
+        affine.yield %5 : f32
+      }
+      %accum_sqr = mulf %accum, %accum : f32
+      affine.store %accum_sqr, %arg2[%arg3, %arg4] : memref<1x64xf32, 1>
+    }
+  }
+  return
+}
+// The two loops here get maximally sibling-fused at the innermost
+// insertion point. Test checks if the innermost reduction loop of the fused loop
+// gets promoted into its outer loop.
+// MAXIMAL-SAME:                             %[[arg_0:.*]]: memref<64x64xf32, 1>,
+// MAXIMAL-SAME:                             %[[arg_1:.*]]: memref<1x64xf32, 1>,
+// MAXIMAL-SAME:                             %[[arg_2:.*]]: memref<1x64xf32, 1>) {
+// MAXIMAL:             %[[cst:.*]] = constant 0 : index
+// MAXIMAL-NEXT:        %[[cst_0:.*]] = constant 0.000000e+00 : f32
+// MAXIMAL-NEXT:        %[[cst_1:.*]] = constant 1.000000e+00 : f32
+// MAXIMAL:             affine.for %[[idx_0:.*]] = 0 to 1 {
+// MAXIMAL-NEXT:          affine.for %[[idx_1:.*]] = 0 to 64 {
+// MAXIMAL-NEXT:            %[[results:.*]]:2 = affine.for %[[idx_2:.*]] = 0 to 64 iter_args(%[[iter_0:.*]] = %[[cst_1]], %[[iter_1:.*]] = %[[cst_0]]) -> (f32, f32) {
+// MAXIMAL-NEXT:              %[[val_0:.*]] = affine.load %[[arg_0]][%[[idx_2]], %[[idx_1]]] : memref<64x64xf32, 1>
+// MAXIMAL-NEXT:              %[[reduc_0:.*]] = addf %[[iter_1]], %[[val_0]] : f32
+// MAXIMAL-NEXT:              %[[val_1:.*]] = affine.load %[[arg_0]][%[[idx_2]], %[[idx_1]]] : memref<64x64xf32, 1>
+// MAXIMAL-NEXT:              %[[reduc_1:.*]] = mulf %[[iter_0]], %[[val_1]] : f32
+// MAXIMAL-NEXT:              affine.yield %[[reduc_1]], %[[reduc_0]] : f32, f32
+// MAXIMAL-NEXT:            }
+// MAXIMAL-NEXT:            %[[reduc_0_dbl:.*]] = addf %[[results]]#1, %[[results]]#1 : f32
+// MAXIMAL-NEXT:            affine.store %[[reduc_0_dbl]], %[[arg_1]][%[[cst]], %[[idx_1]]] : memref<1x64xf32, 1>
+// MAXIMAL-NEXT:            %[[reduc_1_sqr:.*]] = mulf %[[results]]#0, %[[results]]#0 : f32
+// MAXIMAL-NEXT:            affine.store %[[reduc_1_sqr]], %[[arg_2]][%[[idx_0]], %[[idx_1]]] : memref<1x64xf32, 1>
+// MAXIMAL-NEXT:          }
+// MAXIMAL-NEXT:        }
+// MAXIMAL-NEXT:        return
+// MAXIMAL-NEXT:      }
+
+// -----
+
+// CHECK-LABEL:   func @reduce_add_non_innermost
+func @reduce_add_non_innermost(%arg0: memref<64x64xf32, 1>, %arg1: memref<1x64xf32, 1>, %arg2: memref<1x64xf32, 1>) {
+  %cst = constant 0.000000e+00 : f32
+  %cst_0 = constant 1.000000e+00 : f32
+  %0 = memref.alloca() : memref<f32, 1>
+  %1 = memref.alloca() : memref<f32, 1>
+  affine.for %arg3 = 0 to 1 {
+    affine.for %arg4 = 0 to 64 {
+      %accum = affine.for %arg5 = 0 to 64 iter_args (%prevAccum = %cst) -> f32 {
+        %4 = affine.load %arg0[%arg5, %arg4] : memref<64x64xf32, 1>
+        %5 = addf %prevAccum, %4 : f32
+        affine.yield %5 : f32
+      }
+      %accum_dbl = addf %accum, %accum : f32
+      affine.store %accum_dbl, %arg1[%arg3, %arg4] : memref<1x64xf32, 1>
+    }
+  }
+  affine.for %arg3 = 0 to 1 {
+    affine.for %arg4 = 0 to 64 {
+      %accum = affine.for %arg5 = 0 to 64 iter_args (%prevAccum = %cst_0) -> f32 {
+        %4 = affine.load %arg0[%arg5, %arg4] : memref<64x64xf32, 1>
+        %5 = mulf %prevAccum, %4 : f32
+        affine.yield %5 : f32
+      }
+      %accum_sqr = mulf %accum, %accum : f32
+      affine.store %accum_sqr, %arg2[%arg3, %arg4] : memref<1x64xf32, 1>
+    }
+  }
+  return
+}
+// Test checks the loop structure is preserved after sibling fusion.
+// CHECK:         affine.for
+// CHECK-NEXT:      affine.for
+// CHECK-NEXT:        affine.for
+// CHECK:             affine.for
+
+// -----
+func @reduce_add_non_maximal_f32_f32(%arg0: memref<64x64xf32, 1>, %arg1 : memref<1x64xf32, 1>, %arg2 : memref<1x64xf32, 1>) {
+    %cst_0 = constant 0.000000e+00 : f32
+    %cst_1 = constant 1.000000e+00 : f32
+    affine.for %arg3 = 0 to 1 {
+      affine.for %arg4 = 0 to 64 {
+        %accum = affine.for %arg5 = 0 to 64 iter_args (%prevAccum = %cst_0) -> f32 {
+          %4 = affine.load %arg0[%arg5, %arg4] : memref<64x64xf32, 1>
+          %5 = addf %prevAccum, %4 : f32
+          affine.yield %5 : f32
+        }
+        %accum_dbl = addf %accum, %accum : f32
+        affine.store %accum_dbl, %arg1[%arg3, %arg4] : memref<1x64xf32, 1>
+      }
+    }
+    affine.for %arg3 = 0 to 1 {
+      affine.for %arg4 = 0 to 64 {
+        // The trip count of the following loop does not match the corresponding source trip count.
+        %accum = affine.for %arg5 = 0 to 32 iter_args (%prevAccum = %cst_1) -> f32 {
+          %4 = affine.load %arg0[%arg5, %arg4] : memref<64x64xf32, 1>
+          %5 = mulf %prevAccum, %4 : f32
+          affine.yield %5 : f32
+        }
+        %accum_sqr = mulf %accum, %accum : f32
+        affine.store %accum_sqr, %arg2[%arg3, %arg4] : memref<1x64xf32, 1>
+      }
+    }
+    return
+}
+// Test checks the loop structure is preserved after sibling fusion
+// since the destination loop and source loop trip counts do not
+// match.
+// MAXIMAL-LABEL:   func @reduce_add_non_maximal_f32_f32(
+// MAXIMAL:        %[[cst_0:.*]] = constant 0.000000e+00 : f32
+// MAXIMAL-NEXT:        %[[cst_1:.*]] = constant 1.000000e+00 : f32
+// MAXIMAL-NEXT:           affine.for %[[idx_0:.*]] = 0 to 1 {
+// MAXIMAL-NEXT:             affine.for %[[idx_1:.*]] = 0 to 64 {
+// MAXIMAL-NEXT:               %[[result_1:.*]] = affine.for %[[idx_2:.*]] = 0 to 32 iter_args(%[[iter_0:.*]] = %[[cst_1]]) -> (f32) {
+// MAXIMAL-NEXT:                 %[[result_0:.*]] = affine.for %[[idx_3:.*]] = 0 to 64 iter_args(%[[iter_1:.*]] = %[[cst_0]]) -> (f32) {
+
+// -----
+
 // CHECK-LABEL: func @fuse_large_number_of_loops
 func @fuse_large_number_of_loops(%arg0: memref<20x10xf32, 1>, %arg1: memref<20x10xf32, 1>, %arg2: memref<20x10xf32, 1>, %arg3: memref<20x10xf32, 1>, %arg4: memref<20x10xf32, 1>, %arg5: memref<f32, 1>, %arg6: memref<f32, 1>, %arg7: memref<f32, 1>, %arg8: memref<f32, 1>, %arg9: memref<20x10xf32, 1>, %arg10: memref<20x10xf32, 1>, %arg11: memref<20x10xf32, 1>, %arg12: memref<20x10xf32, 1>) {
   %cst = constant 1.000000e+00 : f32


        


More information about the Mlir-commits mailing list