[Mlir-commits] [mlir] 7ab14b8 - [mlir] Unroll-and-jam loops with iter_args.

Amy Zhuang llvmlistbot at llvm.org
Tue Sep 28 14:39:54 PDT 2021


Author: Amy Zhuang
Date: 2021-09-28T14:13:27-07:00
New Revision: 7ab14b8886d9ddaca1f8fc8a34ef8f03af208f26

URL: https://github.com/llvm/llvm-project/commit/7ab14b8886d9ddaca1f8fc8a34ef8f03af208f26
DIFF: https://github.com/llvm/llvm-project/commit/7ab14b8886d9ddaca1f8fc8a34ef8f03af208f26.diff

LOG: [mlir] Unroll-and-jam loops with iter_args.

Unroll-and-jam currently doesn't work when the loop being unroll-and-jammed
or any of its inner loops has iter_args. This patch modifies the
unroll-and-jam utility to support loops with iter_args.

Reviewed By: bondhugula

Differential Revision: https://reviews.llvm.org/D110085

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/AffineAnalysis.h
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
    mlir/include/mlir/Transforms/LoopUtils.h
    mlir/lib/Analysis/AffineAnalysis.cpp
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Dialect/Affine/Transforms/LoopUnrollAndJam.cpp
    mlir/lib/Transforms/Utils/LoopUtils.cpp
    mlir/test/Dialect/Affine/unroll-jam.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h
index 849d22e6938fb..1e5c5b8f87c6e 100644
--- a/mlir/include/mlir/Analysis/AffineAnalysis.h
+++ b/mlir/include/mlir/Analysis/AffineAnalysis.h
@@ -40,6 +40,10 @@ struct LoopReduction {
   Value value;
 };
 
+/// Populate `supportedReductions` with descriptors of the supported reductions.
+void getSupportedReductions(
+    AffineForOp forOp, SmallVectorImpl<LoopReduction> &supportedReductions);
+
 /// Returns true if `forOp' is a parallel loop. If `parallelReductions` is
 /// provided, populates it with descriptors of the parallelizable reductions and
 /// treats them as not preventing parallelization.

diff  --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 67754f1cc3eb9..79dcd8a2b315e 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -265,6 +265,10 @@ def AffineForOp : Affine_Op<"for",
     /// Returns operands for the upper bound map.
     operand_range getUpperBoundOperands();
 
+    /// Returns operands for the lower and upper bound maps with the operands
+    /// for the lower bound map in front of those for the upper bound map.
+    operand_range getControlOperands();
+
     /// Returns information about the lower bound as a single object.
     AffineBound getLowerBound();
 

diff  --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h
index 8c851215e3c08..f2651be24b15c 100644
--- a/mlir/include/mlir/Transforms/LoopUtils.h
+++ b/mlir/include/mlir/Transforms/LoopUtils.h
@@ -66,8 +66,10 @@ void getPerfectlyNestedLoops(SmallVectorImpl<AffineForOp> &nestedLoops,
 void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
                              scf::ForOp root);
 
-/// Unrolls and jams this loop by the specified factor. Returns success if the
-/// loop is successfully unroll-jammed.
+/// Unrolls and jams this loop by the specified factor. `forOp` can be a loop
+/// with iteration arguments performing supported reductions and its inner loops
+/// can have iteration arguments. Returns success if the loop is successfully
+/// unroll-jammed.
 LogicalResult loopUnrollJamByFactor(AffineForOp forOp,
                                     uint64_t unrollJamFactor);
 

diff  --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp
index 1966305cc40e3..d3e4f4bbe9019 100644
--- a/mlir/lib/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -69,6 +69,20 @@ static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
   return reducedVal;
 }
 
+/// Populate `supportedReductions` with descriptors of the supported reductions.
+void mlir::getSupportedReductions(
+    AffineForOp forOp, SmallVectorImpl<LoopReduction> &supportedReductions) {
+  unsigned numIterArgs = forOp.getNumIterOperands();
+  if (numIterArgs == 0)
+    return;
+  supportedReductions.reserve(numIterArgs);
+  for (unsigned i = 0; i < numIterArgs; ++i) {
+    AtomicRMWKind kind;
+    if (Value value = getSupportedReduction(forOp, i, kind))
+      supportedReductions.emplace_back(LoopReduction{kind, i, value});
+  }
+}
+
 /// Returns true if `forOp' is a parallel loop. If `parallelReductions` is
 /// provided, populates it with descriptors of the parallelizable reductions and
 /// treats them as not preventing parallelization.
@@ -83,13 +97,7 @@ bool mlir::isLoopParallel(AffineForOp forOp,
 
   // Find supported reductions of requested.
   if (parallelReductions) {
-    parallelReductions->reserve(forOp.getNumIterOperands());
-    for (unsigned i = 0; i < numIterArgs; ++i) {
-      AtomicRMWKind kind;
-      if (Value value = getSupportedReduction(forOp, i, kind))
-        parallelReductions->emplace_back(LoopReduction{kind, i, value});
-    }
-
+    getSupportedReductions(forOp, *parallelReductions);
     // Return later to allow for identifying all parallel reductions even if the
     // loop is not parallel.
     if (parallelReductions->size() != numIterArgs)

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 7d89d5911e6de..4e1cf4c57f6bc 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1770,6 +1770,11 @@ AffineForOp::operand_range AffineForOp::getUpperBoundOperands() {
               getUpperBoundMap().getNumInputs()};
 }
 
+AffineForOp::operand_range AffineForOp::getControlOperands() {
+  return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs() +
+                               getUpperBoundMap().getNumInputs()};
+}
+
 bool AffineForOp::matchingBoundOperandList() {
   auto lbMap = getLowerBoundMap();
   auto ubMap = getUpperBoundMap();

diff  --git a/mlir/lib/Dialect/Affine/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopUnrollAndJam.cpp
index 55964db7810d1..76da2b07ef191 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopUnrollAndJam.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopUnrollAndJam.cpp
@@ -34,6 +34,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "PassDetail.h"
+#include "mlir/Analysis/AffineAnalysis.h"
 #include "mlir/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Passes.h"

diff  --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 4427001be2204..bf472ad46d4b2 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -1126,6 +1126,40 @@ static void generateUnrolledLoop(
   loopBodyBlock->getTerminator()->setOperands(lastYielded);
 }
 
+/// Helper to generate cleanup loop for unroll or unroll-and-jam when the trip
+/// count is not a multiple of `unrollFactor`.
+static void generateCleanupLoopForUnroll(AffineForOp forOp,
+                                         uint64_t unrollFactor) {
+  // Insert the cleanup loop right after 'forOp'.
+  OpBuilder builder(forOp->getBlock(), std::next(Block::iterator(forOp)));
+  auto cleanupForOp = cast<AffineForOp>(builder.clone(*forOp));
+
+  // Update uses of `forOp` results. `cleanupForOp` should use `forOp` results
+  // and produce results for the original users of `forOp` results.
+  auto results = forOp.getResults();
+  auto cleanupResults = cleanupForOp.getResults();
+  auto cleanupIterOperands = cleanupForOp.getIterOperands();
+
+  for (auto e : llvm::zip(results, cleanupResults, cleanupIterOperands)) {
+    std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
+    cleanupForOp->replaceUsesOfWith(std::get<2>(e), std::get<0>(e));
+  }
+
+  AffineMap cleanupMap;
+  SmallVector<Value, 4> cleanupOperands;
+  getCleanupLoopLowerBound(forOp, unrollFactor, cleanupMap, cleanupOperands);
+  assert(cleanupMap &&
+         "cleanup loop lower bound map for single result lower bound maps "
+         "can always be determined");
+  cleanupForOp.setLowerBound(cleanupOperands, cleanupMap);
+  // Promote the loop body up if this has turned into a single iteration loop.
+  (void)promoteIfSingleIteration(cleanupForOp);
+
+  // Adjust upper bound of the original loop; this is the same as the lower
+  // bound of the cleanup loop.
+  forOp.setUpperBound(cleanupOperands, cleanupMap);
+}
+
 /// Unrolls this loop by the specified factor. Returns success if the loop
 /// is successfully unrolled.
 LogicalResult mlir::loopUnrollByFactor(
@@ -1156,34 +1190,8 @@ LogicalResult mlir::loopUnrollByFactor(
     return failure();
 
   // Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
-  if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) {
-    OpBuilder builder(forOp->getBlock(), std::next(Block::iterator(forOp)));
-    auto cleanupForOp = cast<AffineForOp>(builder.clone(*forOp));
-
-    // Update users of loop results.
-    auto results = forOp.getResults();
-    auto cleanupResults = cleanupForOp.getResults();
-    auto cleanupIterOperands = cleanupForOp.getIterOperands();
-
-    for (auto e : llvm::zip(results, cleanupResults, cleanupIterOperands)) {
-      std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
-      cleanupForOp->replaceUsesOfWith(std::get<2>(e), std::get<0>(e));
-    }
-
-    AffineMap cleanupMap;
-    SmallVector<Value, 4> cleanupOperands;
-    getCleanupLoopLowerBound(forOp, unrollFactor, cleanupMap, cleanupOperands);
-    assert(cleanupMap &&
-           "cleanup loop lower bound map for single result lower bound maps "
-           "can always be determined");
-    cleanupForOp.setLowerBound(cleanupOperands, cleanupMap);
-    // Promote the loop body up if this has turned into a single iteration loop.
-    (void)promoteIfSingleIteration(cleanupForOp);
-
-    // Adjust upper bound of the original loop; this is the same as the lower
-    // bound of the cleanup loop.
-    forOp.setUpperBound(cleanupOperands, cleanupMap);
-  }
+  if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0)
+    generateCleanupLoopForUnroll(forOp, unrollFactor);
 
   ValueRange iterArgs(forOp.getRegionIterArgs());
   auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();
@@ -1330,37 +1338,52 @@ LogicalResult mlir::loopUnrollJamUpToFactor(AffineForOp forOp,
   return loopUnrollJamByFactor(forOp, unrollJamFactor);
 }
 
-/// Unrolls and jams this loop by the specified factor.
-LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp,
-                                          uint64_t unrollJamFactor) {
-  // Gathers all maximal sub-blocks of operations that do not themselves
-  // include a for op (a operation could have a descendant for op though
-  // in its tree).  Ignore the block terminators.
-  struct JamBlockGatherer {
-    // Store iterators to the first and last op of each sub-block found.
-    std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
-
-    // This is a linear time walk.
-    void walk(Operation *op) {
-      for (auto &region : op->getRegions())
-        for (auto &block : region)
-          walk(block);
+/// Check if all control operands of all loops are defined outside of `forOp`
+/// and return false if not.
+static bool areInnerBoundsInvariant(AffineForOp forOp) {
+  auto walkResult = forOp.walk([&](AffineForOp aForOp) {
+    for (auto controlOperand : aForOp.getControlOperands()) {
+      if (!forOp.isDefinedOutsideOfLoop(controlOperand))
+        return WalkResult::interrupt();
     }
+    return WalkResult::advance();
+  });
+  if (walkResult.wasInterrupted())
+    return false;
+  return true;
+}
 
-    void walk(Block &block) {
-      for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
-        auto subBlockStart = it;
-        while (it != e && !isa<AffineForOp>(&*it))
-          ++it;
-        if (it != subBlockStart)
-          subBlocks.push_back({subBlockStart, std::prev(it)});
-        // Process all for ops that appear next.
-        while (it != e && isa<AffineForOp>(&*it))
-          walk(&*it++);
-      }
+// Gathers all maximal sub-blocks of operations that do not themselves
+// include a for op (an operation could have a descendant for op though
+// in its tree).  Ignore the block terminators.
+struct JamBlockGatherer {
+  // Store iterators to the first and last op of each sub-block found.
+  std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
+
+  // This is a linear time walk.
+  void walk(Operation *op) {
+    for (auto &region : op->getRegions())
+      for (auto &block : region)
+        walk(block);
+  }
+
+  void walk(Block &block) {
+    for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
+      auto subBlockStart = it;
+      while (it != e && !isa<AffineForOp>(&*it))
+        ++it;
+      if (it != subBlockStart)
+        subBlocks.push_back({subBlockStart, std::prev(it)});
+      // Process all for ops that appear next.
+      while (it != e && isa<AffineForOp>(&*it))
+        walk(&*it++);
     }
-  };
+  }
+};
 
+/// Unrolls and jams this loop by the specified factor.
+LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp,
+                                          uint64_t unrollJamFactor) {
   assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
 
   if (unrollJamFactor == 1)
@@ -1387,31 +1410,83 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp,
     return failure();
   }
 
+  // If any control operand of any inner loop of `forOp` is defined within
+  // `forOp`, no unroll jam.
+  if (!areInnerBoundsInvariant(forOp))
+    return failure();
+
   // Gather all sub-blocks to jam upon the loop being unrolled.
   JamBlockGatherer jbg;
   jbg.walk(forOp);
   auto &subBlocks = jbg.subBlocks;
 
+  // Collect loops with iter_args.
+  SmallVector<AffineForOp, 4> loopsWithIterArgs;
+  forOp.walk([&](AffineForOp aForOp) {
+    if (aForOp.getNumIterOperands() > 0)
+      loopsWithIterArgs.push_back(aForOp);
+  });
+
+  // Get supported reductions to be used for creating reduction ops at the end.
+  SmallVector<LoopReduction> reductions;
+  if (forOp.getNumIterOperands() > 0)
+    getSupportedReductions(forOp, reductions);
+
   // Generate the cleanup loop if trip count isn't a multiple of
   // unrollJamFactor.
-  if (getLargestDivisorOfTripCount(forOp) % unrollJamFactor != 0) {
-    // Insert the cleanup loop right after 'forOp'.
-    OpBuilder builder(forOp->getBlock(), std::next(Block::iterator(forOp)));
-    auto cleanupAffineForOp = cast<AffineForOp>(builder.clone(*forOp));
-    // Adjust the lower bound of the cleanup loop; its upper bound is the same
-    // as the original loop's upper bound.
-    AffineMap cleanupMap;
-    SmallVector<Value, 4> cleanupOperands;
-    getCleanupLoopLowerBound(forOp, unrollJamFactor, cleanupMap,
-                             cleanupOperands);
-    cleanupAffineForOp.setLowerBound(cleanupOperands, cleanupMap);
-
-    // Promote the cleanup loop if it has turned into a single iteration loop.
-    (void)promoteIfSingleIteration(cleanupAffineForOp);
-
-    // Adjust the upper bound of the original loop - it will be the same as the
-    // cleanup loop's lower bound. Its lower bound remains unchanged.
-    forOp.setUpperBound(cleanupOperands, cleanupMap);
+  if (getLargestDivisorOfTripCount(forOp) % unrollJamFactor != 0)
+    generateCleanupLoopForUnroll(forOp, unrollJamFactor);
+
+  // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
+  // iteration. There are (`unrollJamFactor` - 1) iterations.
+  SmallVector<BlockAndValueMapping, 4> operandMaps(unrollJamFactor - 1);
+
+  // For any loop with iter_args, replace it with a new loop that has
+  // `unrollJamFactor` copies of its iterOperands, iter_args and yield
+  // operands.
+  SmallVector<AffineForOp, 4> newLoopsWithIterArgs;
+  OpBuilder builder(forOp.getContext());
+  for (AffineForOp oldForOp : loopsWithIterArgs) {
+    SmallVector<Value, 4> dupIterOperands, dupIterArgs, dupYieldOperands;
+    ValueRange oldIterOperands = oldForOp.getIterOperands();
+    ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
+    ValueRange oldYieldOperands =
+        cast<AffineYieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
+    // Get additional iterOperands, iterArgs, and yield operands. We will
+    // fix iterOperands and yield operands after cloning of sub-blocks.
+    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+      dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
+      dupIterArgs.append(oldIterArgs.begin(), oldIterArgs.end());
+      dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
+    }
+    // Create a new loop with additional iterOperands, iter_args and yield
+    // operands. This new loop will take the loop body of the original loop.
+    AffineForOp newForOp = mlir::replaceForOpWithNewYields(
+        builder, oldForOp, dupIterOperands, dupYieldOperands, dupIterArgs);
+    newLoopsWithIterArgs.push_back(newForOp);
+    // `forOp` has been replaced with a new loop.
+    if (oldForOp == forOp)
+      forOp = newForOp;
+    assert(oldForOp.use_empty() && "old for op should not have any user");
+    oldForOp.erase();
+    // Update `operandMaps` for `newForOp` iterArgs and results.
+    ValueRange newIterArgs = newForOp.getRegionIterArgs();
+    unsigned oldNumIterArgs = oldIterArgs.size();
+    ValueRange newResults = newForOp.getResults();
+    unsigned oldNumResults = newResults.size() / unrollJamFactor;
+    assert(oldNumIterArgs == oldNumResults &&
+           "oldNumIterArgs must be the same as oldNumResults");
+    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+      for (unsigned j = 0; j < oldNumIterArgs; ++j) {
+        // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
+        // results. Update `operandMaps[i - 1]` to map old iterArgs and results
+        // to those in the `i`th new set.
+        operandMaps[i - 1].map(newIterArgs[j],
+                               newIterArgs[i * oldNumIterArgs + j]);
+        operandMaps[i - 1].map(newResults[j],
+                               newResults[i * oldNumResults + j]);
+      }
+    }
   }
 
   // Scale the step of loop being unroll-jammed by the unroll-jam factor.
@@ -1421,8 +1496,6 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp,
   auto forOpIV = forOp.getInductionVar();
   // Unroll and jam (appends unrollJamFactor - 1 additional copies).
   for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
-    // Operand map persists across all sub-blocks.
-    BlockAndValueMapping operandMap;
     for (auto &subBlock : subBlocks) {
       // Builder to insert unroll-jammed bodies. Insert right at the end of
       // sub-block.
@@ -1431,16 +1504,63 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp,
       // If the induction variable is used, create a remapping to the value for
       // this unrolled instance.
       if (!forOpIV.use_empty()) {
-        // iv' = iv + i, i = 1 to unrollJamFactor-1.
+        // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
         auto d0 = builder.getAffineDimExpr(0);
         auto bumpMap = AffineMap::get(1, 0, d0 + i * step);
         auto ivUnroll =
             builder.create<AffineApplyOp>(forOp.getLoc(), bumpMap, forOpIV);
-        operandMap.map(forOpIV, ivUnroll);
+        operandMaps[i - 1].map(forOpIV, ivUnroll);
       }
       // Clone the sub-block being unroll-jammed.
       for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
-        builder.clone(*it, operandMap);
+        builder.clone(*it, operandMaps[i - 1]);
+    }
+    // Fix iterOperands and yield op operands of newly created loops.
+    for (auto newForOp : newLoopsWithIterArgs) {
+      unsigned oldNumIterOperands =
+          newForOp.getNumIterOperands() / unrollJamFactor;
+      unsigned numControlOperands = newForOp.getNumControlOperands();
+      auto yieldOp = cast<AffineYieldOp>(newForOp.getBody()->getTerminator());
+      unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
+      assert(oldNumIterOperands == oldNumYieldOperands &&
+             "oldNumIterOperands must be the same as oldNumYieldOperands");
+      for (unsigned j = 0; j < oldNumIterOperands; ++j) {
+        // The `i`th duplication of an old iterOperand or yield op operand
+        // needs to be replaced with a mapped value from `operandMaps[i - 1]`
+        // if such mapped value exists.
+        newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
+                            operandMaps[i - 1].lookupOrDefault(
+                                newForOp.getOperand(numControlOperands + j)));
+        yieldOp.setOperand(
+            i * oldNumYieldOperands + j,
+            operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
+      }
+    }
+  }
+  if (forOp.getNumResults() > 0) {
+    // Create reduction ops to combine every `unrollJamFactor` related results
+    // into one value. For example, for %0:2 = affine.for ... and addf, we add
+    // %1 = addf %0#0, %0#1, and replace the following uses of %0#0 with %1.
+    builder.setInsertionPointAfter(forOp);
+    auto loc = forOp.getLoc();
+    unsigned oldNumResults = forOp.getNumResults() / unrollJamFactor;
+    for (LoopReduction &reduction : reductions) {
+      unsigned pos = reduction.iterArgPosition;
+      Value lhs = forOp.getResult(pos);
+      Value rhs;
+      SmallPtrSet<Operation *, 4> newOps;
+      for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+        rhs = forOp.getResult(i * oldNumResults + pos);
+        // Create ops based on reduction type.
+        lhs = getReductionOp(reduction.kind, builder, loc, lhs, rhs);
+        if (!lhs)
+          return failure();
+        Operation *op = lhs.getDefiningOp();
+        assert(op && "Reduction op should have been created");
+        newOps.insert(op);
+      }
+      // Replace all uses except those in newly created reduction ops.
+      forOp.getResult(pos).replaceAllUsesExcept(lhs, newOps);
     }
   }
 

diff  --git a/mlir/test/Dialect/Affine/unroll-jam.mlir b/mlir/test/Dialect/Affine/unroll-jam.mlir
index 9b09918827807..84f276dd993e5 100644
--- a/mlir/test/Dialect/Affine/unroll-jam.mlir
+++ b/mlir/test/Dialect/Affine/unroll-jam.mlir
@@ -122,3 +122,396 @@ func @loop_nest_symbolic_and_min_upper_bound(%M : index, %N : index, %K : index)
 // CHECK-NEXT:    }
 // CHECK-NEXT:  }
 // CHECK-NEXT:  return
+
+// The inner loop trip count changes each iteration of outer loop.
+// Do not unroll-and-jam.
+// CHECK-LABEL: func @no_unroll_jam_dependent_ubound
+func @no_unroll_jam_dependent_ubound(%in0: memref<?xf32, 1>) {
+  affine.for %i = 0 to 100 {
+    affine.for %k = 0 to affine_map<(d0) -> (d0 + 1)>(%i) {
+      %y = "addi32"(%k, %k) : (index, index) -> i32
+    }
+  }
+  return
+}
+// CHECK:      affine.for [[IV0:%arg[0-9]+]] = 0 to 100 {
+// CHECK-NEXT:   affine.for [[IV1:%arg[0-9]+]] = 0 to [[$MAP_PLUS_1]]([[IV0]]) {
+// CHECK-NEXT:     "addi32"([[IV1]], [[IV1]])
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+
+// Inner loop with one iter_arg.
+// CHECK-LABEL: func @unroll_jam_one_iter_arg
+func @unroll_jam_one_iter_arg() {
+  affine.for %i = 0 to 101 {
+    %cst = constant 1 : i32
+    %x = "addi32"(%i, %i) : (index, index) -> i32
+    %red = affine.for %j = 0 to 17 iter_args(%acc = %cst) -> (i32) {
+      %y = "bar"(%i, %j, %acc) : (index, index, i32) -> i32
+      affine.yield %y : i32
+    }
+    %w = "foo"(%i, %x, %red) : (index, i32, i32) -> i32
+  }
+  return
+}
+// CHECK:      affine.for [[IV0:%arg[0-9]+]] = 0 to 100 step 2 {
+// CHECK-NEXT:   [[CONST1:%[a-zA-Z0-9_]*]] = constant 1 : i32
+// CHECK-NEXT:   [[RES1:%[0-9]+]] = "addi32"([[IV0]], [[IV0]])
+// CHECK-NEXT:   [[INC:%[0-9]+]] = affine.apply [[$MAP_PLUS_1]]([[IV0]])
+// CHECK-NEXT:   [[CONST2:%[a-zA-Z0-9_]*]] = constant 1 : i32
+// CHECK-NEXT:   [[RES2:%[0-9]+]] = "addi32"([[INC]], [[INC]])
+// CHECK-NEXT:   [[RES3:%[0-9]+]]:2 = affine.for [[IV1:%arg[0-9]+]] = 0 to 17 iter_args([[ACC1:%arg[0-9]+]] = [[CONST1]], [[ACC2:%arg[0-9]+]] = [[CONST2]]) -> (i32, i32) {
+// CHECK-NEXT:     [[RES4:%[0-9]+]] = "bar"([[IV0]], [[IV1]], [[ACC1]])
+// CHECK-NEXT:     [[INC1:%[0-9]+]] = affine.apply [[$MAP_PLUS_1]]([[IV0]])
+// CHECK-NEXT:     [[RES5:%[0-9]+]] = "bar"([[INC1]], [[IV1]], [[ACC2]])
+// CHECK-NEXT:     affine.yield [[RES4]], [[RES5]]
+// CHECK-NEXT:   }
+// CHECK:        "foo"([[IV0]], [[RES1]], [[RES3]]#0)
+// CHECK-NEXT:   affine.apply [[$MAP_PLUS_1]]([[IV0]])
+// CHECK-NEXT:   "foo"({{.*}}, [[RES2]], [[RES3]]#1)
+// CHECK:      }
+// Cleanup loop (single iteration).
+// CHECK:      constant 1 : i32
+// CHECK-NEXT: "addi32"(%c100, %c100)
+// CHECK-NEXT: [[RES6:%[0-9]+]] = affine.for
+// CHECK-NEXT:   [[RES7:%[0-9]+]] = "bar"(%c100, {{.*}}, {{.*}})
+// CHECK-NEXT:   affine.yield [[RES7]] : i32
+// CHECK-NEXT: }
+// CHECK-NEXT: "foo"(%c100, %{{.*}}, [[RES6]])
+// CHECK-NEXT: return
+
+// Inner loop with multiple iter_args.
+// CHECK-LABEL: func @unroll_jam_iter_args
+func @unroll_jam_iter_args() {
+  affine.for %i = 0 to 101 {
+    %cst = constant 0 : i32
+    %cst1 = constant 1 : i32
+    %x = "addi32"(%i, %i) : (index, index) -> i32
+    %red:2 = affine.for %j = 0 to 17 iter_args(%acc = %cst, %acc1 = %cst1) -> (i32, i32) {
+      %y = "bar"(%i, %j, %acc) : (index, index, i32) -> i32
+      %z = "bar1"(%i, %j, %acc1) : (index, index, i32) -> i32
+      affine.yield %y, %z : i32, i32
+    }
+    %w = "foo"(%i, %x, %red#0, %red#1) : (index, i32, i32, i32) -> i32
+  }
+  return
+}
+// CHECK:      affine.for [[IV0:%arg[0-9]+]] = 0 to 100 step 2 {
+// CHECK-NEXT:   [[CONST0:%[a-zA-Z0-9_]*]] = constant 0 : i32
+// CHECK-NEXT:   [[CONST1:%[a-zA-Z0-9_]*]] = constant 1 : i32
+// CHECK-NEXT:   [[RES1:%[0-9]+]] = "addi32"([[IV0]], [[IV0]])
+// CHECK-NEXT:   [[INC:%[0-9]+]] = affine.apply [[$MAP_PLUS_1]]([[IV0]])
+// CHECK-NEXT:   [[CONST2:%[a-zA-Z0-9_]*]] = constant 0 : i32
+// CHECK-NEXT:   [[CONST3:%[a-zA-Z0-9_]*]] = constant 1 : i32
+// CHECK-NEXT:   [[RES2:%[0-9]+]] = "addi32"([[INC]], [[INC]])
+// CHECK-NEXT:   [[RES3:%[0-9]+]]:4 = affine.for [[IV1:%arg[0-9]+]] = 0 to 17 iter_args([[ACC0:%arg[0-9]+]] = [[CONST0]], [[ACC1:%arg[0-9]+]] = [[CONST1]],
+// CHECK-SAME:   [[ACC2:%arg[0-9]+]] = [[CONST2]], [[ACC3:%arg[0-9]+]] = [[CONST3]]) -> (i32, i32, i32, i32) {
+// CHECK-NEXT:     [[RES4:%[0-9]+]] = "bar"([[IV0]], [[IV1]], [[ACC0]])
+// CHECK-NEXT:     [[RES5:%[0-9]+]] = "bar1"([[IV0]], [[IV1]], [[ACC1]])
+// CHECK-NEXT:     [[INC1:%[0-9]+]] = affine.apply [[$MAP_PLUS_1]]([[IV0]])
+// CHECK-NEXT:     [[RES6:%[0-9]+]] = "bar"([[INC1]], [[IV1]], [[ACC2]])
+// CHECK-NEXT:     [[RES7:%[0-9]+]] = "bar1"([[INC1]], [[IV1]], [[ACC3]])
+// CHECK-NEXT:     affine.yield [[RES4]], [[RES5]], [[RES6]], [[RES7]]
+// CHECK-NEXT:   }
+// CHECK:        "foo"([[IV0]], [[RES1]], [[RES3]]#0, [[RES3]]#1)
+// CHECK-NEXT:   affine.apply [[$MAP_PLUS_1]]([[IV0]])
+// CHECK-NEXT:   "foo"({{.*}}, [[RES2]], [[RES3]]#2, [[RES3]]#3)
+// CHECK:      }
+// Cleanup loop (single iteration).
+// CHECK:      constant 0 : i32
+// CHECK-NEXT: constant 1 : i32
+// CHECK-NEXT: "addi32"(%c100, %c100)
+// CHECK-NEXT: [[RES8:%[0-9]+]]:2 = affine.for
+// CHECK-NEXT:   [[RES9:%[0-9]+]] = "bar"(%c100, {{.*}}, {{.*}})
+// CHECK-NEXT:   [[RES10:%[0-9]+]] = "bar1"(%c100, {{.*}}, {{.*}})
+// CHECK-NEXT:   affine.yield [[RES9]], [[RES10]] : i32, i32
+// CHECK-NEXT: }
+// CHECK-NEXT: "foo"(%c100, %{{.*}}, [[RES8]]#0, [[RES8]]#1)
+// CHECK-NEXT: return
+
+// When an iter operand is a function argument, do not replace any use of the
+// operand.
+// CHECK-LABEL: func @unroll_jam_iter_args_func_arg
+// CHECK-SAME:  [[INIT:%arg[0-9]+]]: i32
+func @unroll_jam_iter_args_func_arg(%in: i32) {
+  affine.for %i = 0 to 101 {
+    %x = "addi32"(%i, %i) : (index, index) -> i32
+    %red = affine.for %j = 0 to 17 iter_args(%acc = %in) -> (i32) {
+      %y = "bar"(%i, %j, %acc) : (index, index, i32) -> i32
+      affine.yield %y : i32
+    }
+    %w = "foo"(%i, %x, %red) : (index, i32, i32) -> i32
+  }
+  return
+}
+// CHECK:      affine.for [[IV0:%arg[0-9]+]] = 0 to 100 step 2 {
+// CHECK-NEXT:   [[RES1:%[0-9]+]] = "addi32"([[IV0]], [[IV0]])
+// CHECK-NEXT:   [[INC:%[0-9]+]] = affine.apply [[$MAP_PLUS_1]]([[IV0]])
+// CHECK-NEXT:   [[RES2:%[0-9]+]] = "addi32"([[INC]], [[INC]])
+// CHECK-NEXT:   [[RES3:%[0-9]+]]:2 = affine.for [[IV1:%arg[0-9]+]] = 0 to 17 iter_args([[ACC1:%arg[0-9]+]] = [[INIT]], [[ACC2:%arg[0-9]+]] = [[INIT]]) -> (i32, i32) {
+// CHECK-NEXT:     [[RES4:%[0-9]+]] = "bar"([[IV0]], [[IV1]], [[ACC1]])
+// CHECK-NEXT:     [[INC1:%[0-9]+]] = affine.apply [[$MAP_PLUS_1]]([[IV0]])
+// CHECK-NEXT:     [[RES5:%[0-9]+]] = "bar"([[INC1]], [[IV1]], [[ACC2]])
+// CHECK-NEXT:     affine.yield [[RES4]], [[RES5]]
+// CHECK-NEXT:   }
+// CHECK:        "foo"([[IV0]], [[RES1]], [[RES3]]#0)
+// CHECK-NEXT:   affine.apply [[$MAP_PLUS_1]]([[IV0]])
+// CHECK-NEXT:   "foo"({{.*}}, [[RES2]], [[RES3]]#1)
+// CHECK:      }
+// Cleanup loop (single iteration).
+// CHECK:      "addi32"(%c100, %c100)
+// CHECK-NEXT: [[RES6:%[0-9]+]] = affine.for
+// CHECK-NEXT:   [[RES7:%[0-9]+]] = "bar"(%c100, {{.*}}, {{.*}})
+// CHECK-NEXT:   affine.yield [[RES7]] : i32
+// CHECK-NEXT: }
+// CHECK-NEXT: "foo"(%c100, %{{.*}}, [[RES6]])
+// CHECK-NEXT: return
+
+// Nested inner loops, each with one iter_arg. The innermost loop uses its
+// outer loop's iter_arg as its iter operand.
+// CHECK-LABEL: func @unroll_jam_iter_args_nested
+func @unroll_jam_iter_args_nested() {
+  affine.for %i = 0 to 101 {
+    %cst = constant 1 : i32
+    %x = "addi32"(%i, %i) : (index, index) -> i32
+    %red = affine.for %j = 0 to 17 iter_args(%acc = %cst) -> (i32) {
+      %red1 = affine.for %k = 0 to 35 iter_args(%acc1 = %acc) -> (i32) {
+        %y = "bar"(%i, %j, %k, %acc1) : (index, index, index, i32) -> i32
+        affine.yield %y : i32
+      }
+      affine.yield %red1 : i32
+    }
+    %w = "foo"(%i, %x, %red) : (index, i32, i32) -> i32
+  }
+  return
+}
+// CHECK:      affine.for [[IV0:%arg[0-9]+]] = 0 to 100 step 2 {
+// CHECK-NEXT:   [[CONST1:%[a-zA-Z0-9_]*]] = constant 1 : i32
+// CHECK-NEXT:   [[RES1:%[0-9]+]] = "addi32"([[IV0]], [[IV0]])
+// CHECK-NEXT:   [[INC:%[0-9]+]] = affine.apply [[$MAP_PLUS_1]]([[IV0]])
+// CHECK-NEXT:   [[CONST2:%[a-zA-Z0-9_]*]] = constant 1 : i32
+// CHECK-NEXT:   [[RES2:%[0-9]+]] = "addi32"([[INC]], [[INC]])
+// CHECK-NEXT:   [[RES3:%[0-9]+]]:2 = affine.for [[IV1:%arg[0-9]+]] = 0 to 17 iter_args([[ACC1:%arg[0-9]+]] = [[CONST1]], [[ACC2:%arg[0-9]+]] = [[CONST2]]) -> (i32, i32) {
+// CHECK-NEXT:     [[RES4:%[0-9]+]]:2 = affine.for [[IV2:%arg[0-9]+]] = 0 to 35 iter_args([[ACC3:%arg[0-9]+]] = [[ACC1]], [[ACC4:%arg[0-9]+]] = [[ACC2]]) -> (i32, i32) {
+// CHECK-NEXT:       [[RES5:%[0-9]+]] = "bar"([[IV0]], [[IV1]], [[IV2]], [[ACC3]])
+// CHECK-NEXT:       [[INC1:%[0-9]+]] = affine.apply [[$MAP_PLUS_1]]([[IV0]])
+// CHECK-NEXT:       [[RES6:%[0-9]+]] = "bar"([[INC1]], [[IV1]], [[IV2]], [[ACC4]])
+// CHECK-NEXT:       affine.yield [[RES5]], [[RES6]]
+// CHECK-NEXT:     }
+// CHECK-NEXT:     affine.yield [[RES4]]#0, [[RES4]]#1
+// CHECK-NEXT:   }
+// CHECK:        "foo"([[IV0]], [[RES1]], [[RES3]]#0)
+// CHECK-NEXT:   affine.apply [[$MAP_PLUS_1]]([[IV0]])
+// CHECK-NEXT:   "foo"({{.*}}, [[RES2]], [[RES3]]#1)
+// CHECK:      }
+// Cleanup loop (single iteration).
+// CHECK:      constant 1 : i32
+// CHECK-NEXT: "addi32"(%c100, %c100)
+// CHECK-NEXT: [[RES6:%[0-9]+]] = affine.for
+// CHECK-NEXT:   [[RES7:%[0-9]+]] = affine.for
+// CHECK-NEXT:     [[RES8:%[0-9]+]] = "bar"(%c100, {{.*}}, {{.*}}, {{.*}})
+// CHECK-NEXT:     affine.yield [[RES8]] : i32
+// CHECK-NEXT:   }
+// CHECK-NEXT:   affine.yield [[RES7]] : i32
+// CHECK-NEXT: }
+// CHECK-NEXT: "foo"(%c100, %{{.*}}, [[RES6]])
+// CHECK-NEXT: return
+
+// Nested inner loops, each with one iter_arg. One loop uses its sibling loop's
+// result as its iter operand.
+// CHECK-LABEL: func @unroll_jam_iter_args_nested_affine_for_result
+func @unroll_jam_iter_args_nested_affine_for_result() {
+  affine.for %i = 0 to 101 {  // Outer loop; the expectations that follow match `0 to 100 step 2` plus a single cleanup iteration.
+    %cst = constant 1 : i32
+    %x = "addi32"(%i, %i) : (index, index) -> i32
+    %red = affine.for %j = 0 to 17 iter_args(%acc = %cst) -> (i32) {
+      %red1 = affine.for %k = 0 to 35 iter_args(%acc1 = %acc) -> (i32) {
+        %y = "bar"(%i, %j, %k, %acc1) : (index, index, index, i32) -> i32
+        affine.yield %acc : i32  // Yields the enclosing loop's iter_arg %acc rather than %y; the expected output likewise yields the outer iter_args.
+      }
+      %red2 = affine.for %l = 0 to 36 iter_args(%acc2 = %red1) -> (i32) {  // Initial value is the sibling loop's result %red1.
+        %y = "bar"(%i, %j, %l, %acc2) : (index, index, index, i32) -> i32
+        affine.yield %y : i32
+      }
+      affine.yield %red2 : i32
+    }
+    %w = "foo"(%i, %x, %red) : (index, i32, i32) -> i32
+  }
+  return
+}
+// CHECK:      affine.for [[IV0:%arg[0-9]+]] = 0 to 100 step 2 {
+// CHECK-NEXT:   [[CONST1:%[a-zA-Z0-9_]*]] = constant 1 : i32
+// CHECK-NEXT:   [[RES1:%[0-9]+]] = "addi32"([[IV0]], [[IV0]])
+// CHECK-NEXT:   [[INC:%[0-9]+]] = affine.apply [[$MAP_PLUS_1]]([[IV0]])
+// CHECK-NEXT:   [[CONST2:%[a-zA-Z0-9_]*]] = constant 1 : i32
+// CHECK-NEXT:   [[RES2:%[0-9]+]] = "addi32"([[INC]], [[INC]])
+// CHECK-NEXT:   [[RES3:%[0-9]+]]:2 = affine.for [[IV1:%arg[0-9]+]] = 0 to 17 iter_args([[ACC1:%arg[0-9]+]] = [[CONST1]], [[ACC2:%arg[0-9]+]] = [[CONST2]]) -> (i32, i32) {
+// CHECK-NEXT:     [[RES4:%[0-9]+]]:2 = affine.for [[IV2:%arg[0-9]+]] = 0 to 35 iter_args([[ACC3:%arg[0-9]+]] = [[ACC1]], [[ACC4:%arg[0-9]+]] = [[ACC2]]) -> (i32, i32) {
+// CHECK-NEXT:       [[RES5:%[0-9]+]] = "bar"([[IV0]], [[IV1]], [[IV2]], [[ACC3]])
+// CHECK-NEXT:       [[INC1:%[0-9]+]] = affine.apply [[$MAP_PLUS_1]]([[IV0]])
+// CHECK-NEXT:       [[RES6:%[0-9]+]] = "bar"([[INC1]], [[IV1]], [[IV2]], [[ACC4]])
+// CHECK-NEXT:       affine.yield [[ACC1]], [[ACC2]]
+// CHECK-NEXT:     }
+// CHECK-NEXT:     [[RES14:%[0-9]+]]:2 = affine.for [[IV3:%arg[0-9]+]] = 0 to 36 iter_args([[ACC13:%arg[0-9]+]] = [[RES4]]#0, [[ACC14:%arg[0-9]+]] = [[RES4]]#1) -> (i32, i32) {
+// CHECK-NEXT:       [[RES15:%[0-9]+]] = "bar"([[IV0]], [[IV1]], [[IV3]], [[ACC13]])
+// CHECK-NEXT:       [[INC1:%[0-9]+]] = affine.apply [[$MAP_PLUS_1]]([[IV0]])
+// CHECK-NEXT:       [[RES16:%[0-9]+]] = "bar"([[INC1]], [[IV1]], [[IV3]], [[ACC14]])
+// CHECK-NEXT:       affine.yield [[RES15]], [[RES16]]
+// CHECK-NEXT:     }
+// CHECK-NEXT:     affine.yield [[RES14]]#0, [[RES14]]#1
+// CHECK-NEXT:   }
+// CHECK:        "foo"([[IV0]], [[RES1]], [[RES3]]#0)
+// CHECK-NEXT:   affine.apply [[$MAP_PLUS_1]]([[IV0]])
+// CHECK-NEXT:   "foo"({{.*}}, [[RES2]], [[RES3]]#1)
+// CHECK:      }
+// Cleanup loop (single iteration).
+// CHECK:      constant 1 : i32
+// CHECK-NEXT: "addi32"(%c100, %c100)
+// CHECK-NEXT: [[RES6:%[0-9]+]] = affine.for
+// CHECK-NEXT:   [[RES7:%[0-9]+]] = affine.for
+// CHECK-NEXT:     [[RES8:%[0-9]+]] = "bar"(%c100, {{.*}}, {{.*}}, {{.*}})
+// CHECK-NEXT:     affine.yield
+// CHECK-NEXT:   }
+// CHECK-NEXT:   [[RES17:%[0-9]+]] = affine.for
+// CHECK-NEXT:     [[RES18:%[0-9]+]] = "bar"(%c100, {{.*}}, {{.*}}, {{.*}})
+// CHECK-NEXT:     affine.yield [[RES18]] : i32
+// CHECK-NEXT:   }
+// CHECK-NEXT:   affine.yield [[RES17]] : i32
+// CHECK-NEXT: }
+// CHECK-NEXT: "foo"(%c100, %{{.*}}, [[RES6]])
+// CHECK-NEXT: return
+
+// Nested inner loops, each with one or more iter_args. Yield the same value
+// multiple times.
+// CHECK-LABEL: func @unroll_jam_iter_args_nested_yield
+func @unroll_jam_iter_args_nested_yield() {
+  affine.for %i = 0 to 101 {  // Outer loop; the expectations that follow match `0 to 100 step 2` plus a single cleanup iteration.
+    %cst = constant 1 : i32
+    %x = "addi32"(%i, %i) : (index, index) -> i32
+    %red:3 = affine.for %j = 0 to 17 iter_args(%acc = %cst, %acc1 = %cst, %acc2 = %cst) -> (i32, i32, i32) {  // All three iter_args start from the same value %cst.
+      %red1 = affine.for %k = 0 to 35 iter_args(%acc3 = %acc) -> (i32) {
+        %y = "bar"(%i, %j, %k, %acc3) : (index, index, index, i32) -> i32
+        affine.yield %y : i32
+      }
+      %red2:2 = affine.for %l = 0 to 36 iter_args(%acc4 = %acc1, %acc5 = %acc2) -> (i32, i32) {
+        %y = "bar1"(%i, %j, %l, %acc4, %acc5) : (index, index, index, i32, i32) -> i32
+        affine.yield %y, %y : i32, i32  // The same value %y is yielded for both results.
+      }
+      affine.yield %red1, %red1, %red2#1 : i32, i32, i32  // %red1 is yielded twice; %red2#0 is not yielded.
+    }
+    %w = "foo"(%i, %x, %red#0, %red#2) : (index, i32, i32, i32) -> i32  // Only results #0 and #2 of %red are consumed.
+  }
+  return
+}
+// CHECK:      affine.for [[IV0:%arg[0-9]+]] = 0 to 100 step 2 {
+// CHECK-NEXT:   [[CONST1:%[a-zA-Z0-9_]*]] = constant 1 : i32
+// CHECK-NEXT:   [[RES1:%[0-9]+]] = "addi32"([[IV0]], [[IV0]])
+// CHECK-NEXT:   [[INC:%[0-9]+]] = affine.apply [[$MAP_PLUS_1]]([[IV0]])
+// CHECK-NEXT:   [[CONST2:%[a-zA-Z0-9_]*]] = constant 1 : i32
+// CHECK-NEXT:   [[RES2:%[0-9]+]] = "addi32"([[INC]], [[INC]])
+// CHECK-NEXT:   [[RES3:%[0-9]+]]:6 = affine.for [[IV1:%arg[0-9]+]] = 0 to 17 iter_args([[ACC1:%arg[0-9]+]] = [[CONST1]], [[ACC2:%arg[0-9]+]] = [[CONST1]],
+// CHECK-SAME:   [[ACC3:%arg[0-9]+]] = [[CONST1]], [[ACC4:%arg[0-9]+]] = [[CONST2]], [[ACC5:%arg[0-9]+]] = [[CONST2]], [[ACC6:%arg[0-9]+]] = [[CONST2]]) -> (i32, i32, i32, i32, i32, i32) {
+// CHECK-NEXT:     [[RES4:%[0-9]+]]:2 = affine.for [[IV2:%arg[0-9]+]] = 0 to 35 iter_args([[ACC7:%arg[0-9]+]] = [[ACC1]], [[ACC8:%arg[0-9]+]] = [[ACC4]]) -> (i32, i32) {
+// CHECK-NEXT:       [[RES5:%[0-9]+]] = "bar"([[IV0]], [[IV1]], [[IV2]], [[ACC7]])
+// CHECK-NEXT:       [[INC1:%[0-9]+]] = affine.apply [[$MAP_PLUS_1]]([[IV0]])
+// CHECK-NEXT:       [[RES6:%[0-9]+]] = "bar"([[INC1]], [[IV1]], [[IV2]], [[ACC8]])
+// CHECK-NEXT:       affine.yield [[RES5]], [[RES6]]
+// CHECK-NEXT:     }
+// CHECK-NEXT:     [[RES14:%[0-9]+]]:4 = affine.for [[IV3:%arg[0-9]+]] = 0 to 36 iter_args([[ACC13:%arg[0-9]+]] = [[ACC2]], [[ACC14:%arg[0-9]+]] = [[ACC3]],
+// CHECK-SAME:     [[ACC15:%arg[0-9]+]] = [[ACC5]], [[ACC16:%arg[0-9]+]] = [[ACC6]]) -> (i32, i32, i32, i32) {
+// CHECK-NEXT:       [[RES15:%[0-9]+]] = "bar1"([[IV0]], [[IV1]], [[IV3]], [[ACC13]], [[ACC14]])
+// CHECK-NEXT:       [[INC1:%[0-9]+]] = affine.apply [[$MAP_PLUS_1]]([[IV0]])
+// CHECK-NEXT:       [[RES16:%[0-9]+]] = "bar1"([[INC1]], [[IV1]], [[IV3]], [[ACC15]], [[ACC16]])
+// CHECK-NEXT:       affine.yield [[RES15]], [[RES15]], [[RES16]], [[RES16]]
+// CHECK-NEXT:     }
+// CHECK-NEXT:     affine.yield [[RES4]]#0, [[RES4]]#0, [[RES14]]#1, [[RES4]]#1, [[RES4]]#1, [[RES14]]#3
+// CHECK-NEXT:   }
+// CHECK:        "foo"([[IV0]], [[RES1]], [[RES3]]#0, [[RES3]]#2)
+// CHECK-NEXT:   affine.apply [[$MAP_PLUS_1]]([[IV0]])
+// CHECK-NEXT:   "foo"({{.*}}, [[RES2]], [[RES3]]#3, [[RES3]]#5)
+// CHECK:      }
+// Cleanup loop (single iteration).
+// CHECK:      constant 1 : i32
+// CHECK-NEXT: "addi32"(%c100, %c100)
+// CHECK-NEXT: [[RES6:%[0-9]+]]:3 = affine.for
+// CHECK-NEXT:   [[RES7:%[0-9]+]] = affine.for
+// CHECK-NEXT:     [[RES8:%[0-9]+]] = "bar"(%c100, {{.*}}, {{.*}}, {{.*}})
+// CHECK-NEXT:     affine.yield [[RES8]] : i32
+// CHECK-NEXT:   }
+// CHECK-NEXT:   [[RES17:%[0-9]+]]:2 = affine.for
+// CHECK-NEXT:     [[RES18:%[0-9]+]] = "bar1"(%c100, {{.*}}, {{.*}}, {{.*}}, {{.*}})
+// CHECK-NEXT:     affine.yield [[RES18]], [[RES18]] : i32, i32
+// CHECK-NEXT:   }
+// CHECK-NEXT:   affine.yield [[RES7]], [[RES7]], [[RES17]]#1 : i32, i32, i32
+// CHECK-NEXT: }
+// CHECK-NEXT: "foo"(%c100, %{{.*}}, [[RES6]]#0, [[RES6]]#2)
+// CHECK-NEXT: return
+
+// CHECK-LABEL: func @unroll_jam_nested_iter_args_mulf
+// CHECK-SAME:  [[INIT0:%arg[0-9]+]]: f32, [[INIT1:%arg[0-9]+]]: f32
+func @unroll_jam_nested_iter_args_mulf(%arg0: memref<21x30xf32, 1>, %init : f32, %init1 : f32) {
+  %0 = affine.for %arg3 = 0 to 21 iter_args(%arg4 = %init) -> (f32) {  // The loop with iter_args is itself the outermost loop here; expectations match `0 to 20 step 2`.
+    %1 = affine.for %arg5 = 0 to 30 iter_args(%arg6 = %init1) -> (f32) {  // Inner accumulator starts from the function argument %init1 on every outer iteration.
+      %3 = affine.load %arg0[%arg3, %arg5] : memref<21x30xf32, 1>
+      %4 = addf %arg6, %3 : f32
+      affine.yield %4 : f32
+    }
+    %2 = mulf %arg4, %1 : f32  // Outer iter_arg is combined with the inner result via mulf; the expected output combines the two unrolled results with a mulf as well.
+    affine.yield %2 : f32
+  }
+  return  // The loop result %0 is not used.
+}
+
+// CHECK:      %[[CONST0:[a-zA-Z0-9_]*]] = constant 20 : index
+// CHECK-NEXT: [[RES:%[0-9]+]]:2 = affine.for %[[IV0:arg[0-9]+]] = 0 to 20 step 2 iter_args([[ACC0:%arg[0-9]+]] = [[INIT0]], [[ACC1:%arg[0-9]+]] = [[INIT0]]) -> (f32, f32) {
+// CHECK-NEXT:   [[RES1:%[0-9]+]]:2 = affine.for %[[IV1:arg[0-9]+]] = 0 to 30 iter_args([[ACC2:%arg[0-9]+]] = [[INIT1]], [[ACC3:%arg[0-9]+]] = [[INIT1]]) -> (f32, f32) {
+// CHECK-NEXT:     [[LOAD1:%[0-9]+]] = affine.load {{.*}}[%[[IV0]], %[[IV1]]]
+// CHECK-NEXT:     [[ADD1:%[0-9]+]] = addf [[ACC2]], [[LOAD1]] : f32
+// CHECK-NEXT:     %[[INC1:[0-9]+]] = affine.apply [[$MAP_PLUS_1]](%[[IV0]])
+// CHECK-NEXT:     [[LOAD2:%[0-9]+]] = affine.load {{.*}}[%[[INC1]], %[[IV1]]]
+// CHECK-NEXT:     [[ADD2:%[0-9]+]] = addf [[ACC3]], [[LOAD2]] : f32
+// CHECK-NEXT:     affine.yield [[ADD1]], [[ADD2]]
+// CHECK-NEXT:   }
+// CHECK-NEXT:   [[MUL1:%[0-9]+]] = mulf [[ACC0]], [[RES1]]#0 : f32
+// CHECK-NEXT:   affine.apply
+// CHECK-NEXT:   [[MUL2:%[0-9]+]] = mulf [[ACC1]], [[RES1]]#1 : f32
+// CHECK-NEXT:   affine.yield [[MUL1]], [[MUL2]]
+// CHECK-NEXT: }
+// Reduction op.
+// CHECK-NEXT: [[MUL3:%[0-9]+]] = mulf [[RES]]#0, [[RES]]#1 : f32
+// Cleanup loop (single iteration).
+// CHECK-NEXT: [[RES2:%[0-9]+]] = affine.for %[[IV2:arg[0-9]+]] = 0 to 30 iter_args([[ACC4:%arg[0-9]+]] = [[INIT1]]) -> (f32) {
+// CHECK-NEXT:   [[LOAD3:%[0-9]+]] = affine.load {{.*}}[%[[CONST0]], %[[IV2]]]
+// CHECK-NEXT:   [[ADD3:%[0-9]+]] = addf [[ACC4]], [[LOAD3]] : f32
+// CHECK-NEXT:   affine.yield [[ADD3]] : f32
+// CHECK-NEXT: }
+// CHECK-NEXT: [[MUL4:%[0-9]+]] = mulf [[MUL3]], [[RES2]] : f32
+// CHECK-NEXT: return
+
+// CHECK-LABEL: func @unroll_jam_iter_args_addi
+// CHECK-SAME:  [[INIT0:%arg[0-9]+]]: i32
+func @unroll_jam_iter_args_addi(%arg0: memref<21xi32, 1>, %init : i32) {
+  %0 = affine.for %arg3 = 0 to 21 iter_args(%arg4 = %init) -> (i32) {  // Single loop (no nesting) carrying one i32 value; expectations match `0 to 20 step 2`.
+    %1 = affine.load %arg0[%arg3] : memref<21xi32, 1>
+    %2 = addi %arg4, %1 : i32  // Accumulates the loaded element into the iter_arg.
+    affine.yield %2 : i32
+  }
+  return  // The loop result %0 is not used.
+}
+
+// CHECK:      %[[CONST0:[a-zA-Z0-9_]*]] = constant 20 : index
+// CHECK-NEXT: [[RES:%[0-9]+]]:2 = affine.for %[[IV0:arg[0-9]+]] = 0 to 20 step 2 iter_args([[ACC0:%arg[0-9]+]] = [[INIT0]], [[ACC1:%arg[0-9]+]] = [[INIT0]]) -> (i32, i32) {
+// CHECK-NEXT:   [[LOAD1:%[0-9]+]] = affine.load {{.*}}[%[[IV0]]]
+// CHECK-NEXT:   [[ADD1:%[0-9]+]] = addi [[ACC0]], [[LOAD1]] : i32
+// CHECK-NEXT:   %[[INC1:[0-9]+]] = affine.apply [[$MAP_PLUS_1]](%[[IV0]])
+// CHECK-NEXT:   [[LOAD2:%[0-9]+]] = affine.load {{.*}}[%[[INC1]]]
+// CHECK-NEXT:   [[ADD2:%[0-9]+]] = addi [[ACC1]], [[LOAD2]] : i32
+// CHECK-NEXT:   affine.yield [[ADD1]], [[ADD2]]
+// CHECK-NEXT: }
+// Reduction op.
+// CHECK-NEXT: [[ADD3:%[0-9]+]] = addi [[RES]]#0, [[RES]]#1 : i32
+// Cleanup loop (single iteration).
+// CHECK-NEXT: [[LOAD3:%[0-9]+]] = affine.load {{.*}}[%[[CONST0]]]
+// CHECK-NEXT: [[ADD4:%[0-9]+]] = addi [[ADD3]], [[LOAD3]] : i32
+// CHECK-NEXT: return


        


More information about the Mlir-commits mailing list