[Mlir-commits] [mlir] 545fa37 - [mlir] Affine: parallelize affine loops with reductions

Alex Zinenko llvmlistbot at llvm.org
Thu Apr 29 04:16:33 PDT 2021


Author: Alex Zinenko
Date: 2021-04-29T13:16:24+02:00
New Revision: 545fa37834ef6b5731444728c00e7a18d4f1aeed

URL: https://github.com/llvm/llvm-project/commit/545fa37834ef6b5731444728c00e7a18d4f1aeed
DIFF: https://github.com/llvm/llvm-project/commit/545fa37834ef6b5731444728c00e7a18d4f1aeed.diff

LOG: [mlir] Affine: parallelize affine loops with reductions

Introduce basic support for parallelizing affine loops with reductions
expressed using iteration arguments. The affine parallelism detector now has a
flag to assume such reductions are parallel. The transformation handles the
subset of parallel reductions that can be expressed using affine.parallel:
integer/float addition and multiplication. This requires detecting the
reduction operation since affine.parallel only supports a fixed set of
reduction operators.

Reviewed By: chelini, kumasento, bondhugula

Differential Revision: https://reviews.llvm.org/D101171

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/AffineAnalysis.h
    mlir/include/mlir/Analysis/Utils.h
    mlir/include/mlir/Dialect/Affine/Passes.td
    mlir/include/mlir/Dialect/Affine/Utils.h
    mlir/lib/Analysis/AffineAnalysis.cpp
    mlir/lib/Analysis/Utils.cpp
    mlir/lib/Dialect/Affine/Transforms/AffineParallelize.cpp
    mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
    mlir/lib/Dialect/Affine/Utils/CMakeLists.txt
    mlir/lib/Dialect/Affine/Utils/Utils.cpp
    mlir/test/Dialect/Affine/parallelize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h
index c3aaa40bda9a0..2a4bd478a513c 100644
--- a/mlir/include/mlir/Analysis/AffineAnalysis.h
+++ b/mlir/include/mlir/Analysis/AffineAnalysis.h
@@ -15,6 +15,7 @@
 #ifndef MLIR_ANALYSIS_AFFINE_ANALYSIS_H
 #define MLIR_ANALYSIS_AFFINE_ANALYSIS_H
 
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Value.h"
 #include "llvm/ADT/Optional.h"
 #include "llvm/ADT/SmallVector.h"
@@ -27,6 +28,25 @@ class AffineValueMap;
 class FlatAffineConstraints;
 class Operation;
 
+/// A description of a (parallelizable) reduction in an affine loop.
+struct LoopReduction {
+  /// Reduction kind.
+  AtomicRMWKind kind;
+
+  /// Position of the iteration argument that acts as accumulator.
+  unsigned iterArgPosition;
+
+  /// The value being reduced.
+  Value value;
+};
+
+/// Returns true if `forOp` is a parallel loop. If `parallelReductions` is
+/// provided, populates it with descriptors of the parallelizable reductions and
+/// treats them as not preventing parallelization.
+bool isLoopParallel(
+    AffineForOp forOp,
+    SmallVectorImpl<LoopReduction> *parallelReductions = nullptr);
+
 /// Returns in `affineApplyOps`, the sequence of those AffineApplyOp
 /// Operations that are reachable via a search starting from `operands` and
 /// ending at those operands that are not the result of an AffineApplyOp.

diff  --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h
index ccedd17e0e6c5..9f231dca44e03 100644
--- a/mlir/include/mlir/Analysis/Utils.h
+++ b/mlir/include/mlir/Analysis/Utils.h
@@ -354,9 +354,6 @@ unsigned getNumCommonSurroundingLoops(Operation &A, Operation &B);
 Optional<int64_t> getMemoryFootprintBytes(AffineForOp forOp,
                                           int memorySpace = -1);
 
-/// Returns true if `forOp' is a parallel loop.
-bool isLoopParallel(AffineForOp forOp);
-
 /// Simplify the integer set by simplifying the underlying affine expressions by
 /// flattening and some simple inference. Also, drop any duplicate constraints.
 /// Returns the simplified integer set. This method runs in time linear in the

diff  --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
index 77ba06483304c..45c28bbace4ae 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -123,6 +123,9 @@ def AffineParallelize : FunctionPass<"affine-parallelize"> {
     Option<"maxNested", "max-nested", "unsigned", /*default=*/"-1u",
            "Maximum number of nested parallel loops to produce. "
            "Defaults to unlimited (UINT_MAX).">,
+    Option<"parallelReductions", "parallel-reductions", "bool",
+           /*default=*/"false",
+           "Whether to parallelize reduction loops. Defaults to false.">
   ];
 }
 

diff  --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index be6985dfe4034..676b394398106 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -24,12 +24,17 @@ class AffineForOp;
 class AffineIfOp;
 class AffineParallelOp;
 struct LogicalResult;
+struct LoopReduction;
 class Operation;
 
 /// Replaces parallel affine.for op with 1-d affine.parallel op.
-/// mlir::isLoopParallel detect the parallel affine.for ops.
+/// mlir::isLoopParallel detects the parallel affine.for ops.
+/// Parallelizes the specified reductions. Parallelization will fail in presence
+/// of loop iteration arguments that are not listed in `parallelReductions`.
 /// There is no cost model currently used to drive this parallelization.
-void affineParallelize(AffineForOp forOp);
+LogicalResult
+affineParallelize(AffineForOp forOp,
+                  ArrayRef<LoopReduction> parallelReductions = {});
 
 /// Hoists out affine.if/else to as high as possible, i.e., past all invariant
 /// affine.fors/parallel's. Returns success if any hoisting happened; folded` is

diff  --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp
index b49e532eec4de..396ab89b5b526 100644
--- a/mlir/lib/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
@@ -21,6 +22,8 @@
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/Support/MathExtras.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -30,6 +33,131 @@ using namespace mlir;
 
 using llvm::dbgs;
 
+/// Returns true if `value` (transitively) depends on iteration arguments of the
+/// given `forOp`.
+static bool dependsOnIterArgs(Value value, AffineForOp forOp) {
+  // Compute the backward slice of the value.
+  SetVector<Operation *> slice;
+  getBackwardSlice(value, &slice,
+                   [&](Operation *op) { return !forOp->isAncestor(op); });
+
+  // Check that none of the operands of the operations in the backward slice are
+  // loop iteration arguments, and neither is the value itself.
+  auto argRange = forOp.getRegionIterArgs();
+  llvm::SmallPtrSet<Value, 8> iterArgs(argRange.begin(), argRange.end());
+  if (iterArgs.contains(value))
+    return true;
+
+  for (Operation *op : slice)
+    for (Value operand : op->getOperands())
+      if (iterArgs.contains(operand))
+        return true;
+
+  return false;
+}
+
+/// Get the value that is being reduced by `pos`-th reduction in the loop if
+/// such a reduction can be performed by affine parallel loops. This assumes
+/// floating-point operations are commutative. On success, `kind` will be the
+/// reduction kind suitable for use in affine parallel loop builder. If the
+/// reduction is not supported, returns null.
+static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
+                                   AtomicRMWKind &kind) {
+  auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->back());
+  Value yielded = yieldOp.operands()[pos];
+  Operation *definition = yielded.getDefiningOp();
+  if (!definition)
+    return nullptr;
+  if (!forOp.getRegionIterArgs()[pos].hasOneUse())
+    return nullptr;
+
+  Optional<AtomicRMWKind> maybeKind =
+      TypeSwitch<Operation *, Optional<AtomicRMWKind>>(definition)
+          .Case<AddFOp>([](Operation *) { return AtomicRMWKind::addf; })
+          .Case<MulFOp>([](Operation *) { return AtomicRMWKind::mulf; })
+          .Case<AddIOp>([](Operation *) { return AtomicRMWKind::addi; })
+          .Case<MulIOp>([](Operation *) { return AtomicRMWKind::muli; })
+          .Default([](Operation *) -> Optional<AtomicRMWKind> {
+            // TODO: AtomicRMW supports other kinds of reductions this is
+            // currently not detecting, add those when the need arises.
+            return llvm::None;
+          });
+  if (!maybeKind)
+    return nullptr;
+
+  kind = *maybeKind;
+  if (definition->getOperand(0) == forOp.getRegionIterArgs()[pos] &&
+      !dependsOnIterArgs(definition->getOperand(1), forOp))
+    return definition->getOperand(1);
+  if (definition->getOperand(1) == forOp.getRegionIterArgs()[pos] &&
+      !dependsOnIterArgs(definition->getOperand(0), forOp))
+    return definition->getOperand(0);
+
+  return nullptr;
+}
+
+/// Returns true if `forOp` is a parallel loop. If `parallelReductions` is
+/// provided, populates it with descriptors of the parallelizable reductions and
+/// treats them as not preventing parallelization.
+bool mlir::isLoopParallel(AffineForOp forOp,
+                          SmallVectorImpl<LoopReduction> *parallelReductions) {
+  unsigned numIterArgs = forOp.getNumIterOperands();
+
+  // Loop is not parallel if it has SSA loop-carried dependences and reduction
+  // detection is not requested.
+  if (numIterArgs > 0 && !parallelReductions)
+    return false;
+
+  // Find supported reductions if requested.
+  if (parallelReductions) {
+    parallelReductions->reserve(forOp.getNumIterOperands());
+    for (unsigned i = 0; i < numIterArgs; ++i) {
+      AtomicRMWKind kind;
+      if (Value value = getSupportedReduction(forOp, i, kind))
+        parallelReductions->emplace_back(LoopReduction{kind, i, value});
+    }
+
+    // Return later to allow for identifying all parallel reductions even if the
+    // loop is not parallel.
+    if (parallelReductions->size() != numIterArgs)
+      return false;
+  }
+
+  // Collect all load and store ops in loop nest rooted at 'forOp'.
+  SmallVector<Operation *, 8> loadAndStoreOps;
+  auto walkResult = forOp.walk([&](Operation *op) -> WalkResult {
+    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
+      loadAndStoreOps.push_back(op);
+    else if (!isa<AffineForOp, AffineYieldOp, AffineIfOp>(op) &&
+             !MemoryEffectOpInterface::hasNoEffect(op))
+      return WalkResult::interrupt();
+
+    return WalkResult::advance();
+  });
+
+  // Stop early if the loop has unknown ops with side effects.
+  if (walkResult.wasInterrupted())
+    return false;
+
+  // Dep check depth would be number of enclosing loops + 1.
+  unsigned depth = getNestingDepth(forOp) + 1;
+
+  // Check dependences between all pairs of ops in 'loadAndStoreOps'.
+  for (auto *srcOp : loadAndStoreOps) {
+    MemRefAccess srcAccess(srcOp);
+    for (auto *dstOp : loadAndStoreOps) {
+      MemRefAccess dstAccess(dstOp);
+      FlatAffineConstraints dependenceConstraints;
+      DependenceResult result = checkMemrefAccessDependence(
+          srcAccess, dstAccess, depth, &dependenceConstraints,
+          /*dependenceComponents=*/nullptr);
+      if (result.value != DependenceResult::NoDependence)
+        return false;
+    }
+  }
+  return true;
+}
+
 /// Returns the sequence of AffineApplyOp Operations operation in
 /// 'affineApplyOps', which are reachable via a search starting from 'operands',
 /// and ending at operands which are not defined by AffineApplyOps.

diff  --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index a4b8ccfc7ad14..e87ecdac2d6ca 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -1268,49 +1268,6 @@ void mlir::getSequentialLoops(AffineForOp forOp,
   });
 }
 
-/// Returns true if 'forOp' is parallel.
-bool mlir::isLoopParallel(AffineForOp forOp) {
-  // Loop is not parallel if it has SSA loop-carried dependences.
-  // TODO: Conditionally support reductions and other loop-carried dependences
-  // that could be handled in the context of a parallel loop.
-  if (forOp.getNumIterOperands() > 0)
-    return false;
-
-  // Collect all load and store ops in loop nest rooted at 'forOp'.
-  SmallVector<Operation *, 8> loadAndStoreOpInsts;
-  auto walkResult = forOp.walk([&](Operation *opInst) -> WalkResult {
-    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(opInst))
-      loadAndStoreOpInsts.push_back(opInst);
-    else if (!isa<AffineForOp, AffineYieldOp, AffineIfOp>(opInst) &&
-             !MemoryEffectOpInterface::hasNoEffect(opInst))
-      return WalkResult::interrupt();
-
-    return WalkResult::advance();
-  });
-
-  // Stop early if the loop has unknown ops with side effects.
-  if (walkResult.wasInterrupted())
-    return false;
-
-  // Dep check depth would be number of enclosing loops + 1.
-  unsigned depth = getNestingDepth(forOp) + 1;
-
-  // Check dependences between all pairs of ops in 'loadAndStoreOpInsts'.
-  for (auto *srcOpInst : loadAndStoreOpInsts) {
-    MemRefAccess srcAccess(srcOpInst);
-    for (auto *dstOpInst : loadAndStoreOpInsts) {
-      MemRefAccess dstAccess(dstOpInst);
-      FlatAffineConstraints dependenceConstraints;
-      DependenceResult result = checkMemrefAccessDependence(
-          srcAccess, dstAccess, depth, &dependenceConstraints,
-          /*dependenceComponents=*/nullptr);
-      if (result.value != DependenceResult::NoDependence)
-        return false;
-    }
-  }
-  return true;
-}
-
 IntegerSet mlir::simplifyIntegerSet(IntegerSet set) {
   FlatAffineConstraints fac(set);
   if (fac.isEmpty())

diff  --git a/mlir/lib/Dialect/Affine/Transforms/AffineParallelize.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineParallelize.cpp
index 55ae51376db7c..62519908a248f 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineParallelize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineParallelize.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "PassDetail.h"
+#include "mlir/Analysis/AffineAnalysis.h"
 #include "mlir/Analysis/AffineStructures.h"
 #include "mlir/Analysis/LoopAnalysis.h"
 #include "mlir/Analysis/Utils.h"
@@ -33,6 +34,17 @@ namespace {
 struct AffineParallelize : public AffineParallelizeBase<AffineParallelize> {
   void runOnFunction() override;
 };
+
+/// Descriptor of a potentially parallelizable loop.
+struct ParallelizationCandidate {
+  ParallelizationCandidate(AffineForOp l, SmallVector<LoopReduction> &&r)
+      : loop(l), reductions(std::move(r)) {}
+
+  /// The potentially parallelizable loop.
+  AffineForOp loop;
+  /// Descriptors of reductions that can be parallelized in the loop.
+  SmallVector<LoopReduction> reductions;
+};
 } // namespace
 
 void AffineParallelize::runOnFunction() {
@@ -41,14 +53,16 @@ void AffineParallelize::runOnFunction() {
   // The walker proceeds in post-order, but we need to process outer loops first
   // to control the number of outer parallel loops, so push candidate loops to
   // the front of a deque.
-  std::deque<AffineForOp> parallelizableLoops;
+  std::deque<ParallelizationCandidate> parallelizableLoops;
   f.walk([&](AffineForOp loop) {
-    if (isLoopParallel(loop))
-      parallelizableLoops.push_front(loop);
+    SmallVector<LoopReduction> reductions;
+    if (isLoopParallel(loop, parallelReductions ? &reductions : nullptr))
+      parallelizableLoops.emplace_back(loop, std::move(reductions));
   });
 
-  for (AffineForOp loop : parallelizableLoops) {
+  for (const ParallelizationCandidate &candidate : parallelizableLoops) {
     unsigned numParentParallelOps = 0;
+    AffineForOp loop = candidate.loop;
     for (Operation *op = loop->getParentOp();
          op != nullptr && !op->hasTrait<OpTrait::AffineScope>();
          op = op->getParentOp()) {
@@ -56,8 +70,15 @@ void AffineParallelize::runOnFunction() {
         ++numParentParallelOps;
     }
 
-    if (numParentParallelOps < maxNested)
-      affineParallelize(loop);
+    if (numParentParallelOps < maxNested) {
+      if (failed(affineParallelize(loop, candidate.reductions))) {
+        LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "] failed to parallelize\n"
+                                << loop);
+      }
+    } else {
+      LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "] too many nested loops\n"
+                              << loop);
+    }
   }
 }
 

diff  --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index a8fa831c1a1e3..551a8d44042ed 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -12,16 +12,16 @@
 //===----------------------------------------------------------------------===//
 
 #include "PassDetail.h"
+#include "mlir/Analysis/AffineAnalysis.h"
 #include "mlir/Analysis/LoopAnalysis.h"
 #include "mlir/Analysis/NestedMatcher.h"
-#include "mlir/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorUtils.h"
 #include "mlir/IR/BlockAndValueMapping.h"
-#include "llvm/Support/Debug.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/Support/Debug.h"
 
 using namespace mlir;
 using namespace vector;

diff  --git a/mlir/lib/Dialect/Affine/Utils/CMakeLists.txt b/mlir/lib/Dialect/Affine/Utils/CMakeLists.txt
index e4a5d0bbd9f15..3bc37cfa3ba23 100644
--- a/mlir/lib/Dialect/Affine/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Utils/CMakeLists.txt
@@ -6,5 +6,6 @@ add_mlir_dialect_library(MLIRAffineUtils
 
   LINK_LIBS PUBLIC
   MLIRAffine
+  MLIRAnalysis
   MLIRTransformUtils
   )

diff  --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 522cfd7fca950..935abdb477bdb 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -12,6 +12,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -130,8 +132,17 @@ static AffineIfOp hoistAffineIfOp(AffineIfOp ifOp, Operation *hoistOverOp) {
 }
 
 /// Replace affine.for with a 1-d affine.parallel and clone the former's body
-/// into the latter while remapping values.
-void mlir::affineParallelize(AffineForOp forOp) {
+/// into the latter while remapping values. Parallelizes the specified
+/// reductions. Parallelization will fail in presence of loop iteration
+/// arguments that are not listed in `parallelReductions`.
+LogicalResult
+mlir::affineParallelize(AffineForOp forOp,
+                        ArrayRef<LoopReduction> parallelReductions) {
+  // Fail early if there are iter arguments that are not reductions.
+  unsigned numReductions = parallelReductions.size();
+  if (numReductions != forOp.getNumIterOperands())
+    return failure();
+
   Location loc = forOp.getLoc();
   OpBuilder outsideBuilder(forOp);
 
@@ -148,7 +159,7 @@ void mlir::affineParallelize(AffineForOp forOp) {
   if (needsMax || needsMin) {
     if (forOp->getParentOp() &&
         !forOp->getParentOp()->hasTrait<OpTrait::AffineScope>())
-      return;
+      return failure();
 
     identityMap = AffineMap::getMultiDimIdentityMap(1, loc->getContext());
   }
@@ -168,12 +179,46 @@ void mlir::affineParallelize(AffineForOp forOp) {
   }
 
   // Creating empty 1-D affine.parallel op.
+  auto reducedValues = llvm::to_vector<4>(llvm::map_range(
+      parallelReductions, [](const LoopReduction &red) { return red.value; }));
+  auto reductionKinds = llvm::to_vector<4>(llvm::map_range(
+      parallelReductions, [](const LoopReduction &red) { return red.kind; }));
   AffineParallelOp newPloop = outsideBuilder.create<AffineParallelOp>(
-      loc, llvm::None, llvm::None, lowerBoundMap, lowerBoundOperands,
-      upperBoundMap, upperBoundOperands);
-  // Steal the body of the old affine for op and erase it.
+      loc, ValueRange(reducedValues).getTypes(), reductionKinds, lowerBoundMap,
+      lowerBoundOperands, upperBoundMap, upperBoundOperands);
+  // Steal the body of the old affine for op.
   newPloop.region().takeBody(forOp.region());
+  Operation *yieldOp = &newPloop.getBody()->back();
+
+  // Handle the initial values of reductions because the parallel loop always
+  // starts from the neutral value.
+  SmallVector<Value> newResults;
+  newResults.reserve(numReductions);
+  for (unsigned i = 0; i < numReductions; ++i) {
+    Value init = forOp.getIterOperands()[i];
+    // This works because we are only handling single-op reductions at the
+    // moment. A switch on reduction kind or a mechanism to collect operations
+    // participating in the reduction will be necessary for multi-op reductions.
+    Operation *reductionOp = yieldOp->getOperand(i).getDefiningOp();
+    assert(reductionOp && "yielded value is expected to be produced by an op");
+    outsideBuilder.getInsertionBlock()->getOperations().splice(
+        outsideBuilder.getInsertionPoint(), newPloop.getBody()->getOperations(),
+        reductionOp);
+    reductionOp->setOperands({init, newPloop->getResult(i)});
+    forOp->getResult(i).replaceAllUsesWith(reductionOp->getResult(0));
+  }
+
+  // Update the loop terminator to yield reduced values bypassing the reduction
+  // operation itself (now moved outside of the loop) and erase the block
+  // arguments that correspond to reductions. Note that the loop always has one
+  // "main" induction variable when coming from a non-parallel for.
+  unsigned numIVs = 1;
+  yieldOp->setOperands(reducedValues);
+  newPloop.getBody()->eraseArguments(
+      llvm::to_vector<4>(llvm::seq<unsigned>(numIVs, numReductions + numIVs)));
+
   forOp.erase();
+  return success();
 }
 
 // Returns success if any hoisting happened.

diff  --git a/mlir/test/Dialect/Affine/parallelize.mlir b/mlir/test/Dialect/Affine/parallelize.mlir
index ca72e67f91080..9bd479c91a000 100644
--- a/mlir/test/Dialect/Affine/parallelize.mlir
+++ b/mlir/test/Dialect/Affine/parallelize.mlir
@@ -1,5 +1,6 @@
-// RUN: mlir-opt %s -allow-unregistered-dialect -affine-parallelize| FileCheck %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -affine-parallelize | FileCheck %s
 // RUN: mlir-opt %s -allow-unregistered-dialect -affine-parallelize='max-nested=1' | FileCheck --check-prefix=MAX-NESTED %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -affine-parallelize='parallel-reductions=1' | FileCheck --check-prefix=REDUCE %s
 
 // CHECK-LABEL:    func @reduce_window_max() {
 func @reduce_window_max() {
@@ -159,24 +160,34 @@ func @max_nested(%m: memref<?x?xf32>, %lb0: index, %lb1: index,
   return
 }
 
-// CHECK-LABEL: @unsupported_iter_args
-func @unsupported_iter_args(%in: memref<10xf32>) {
+// CHECK-LABEL: @iter_args
+// REDUCE-LABEL: @iter_args
+func @iter_args(%in: memref<10xf32>) {
+  // REDUCE: %[[init:.*]] = constant
   %cst = constant 0.000000e+00 : f32
   // CHECK-NOT: affine.parallel
+  // REDUCE: %[[reduced:.*]] = affine.parallel (%{{.*}}) = (0) to (10) reduce ("addf")
   %final_red = affine.for %i = 0 to 10 iter_args(%red_iter = %cst) -> (f32) {
+    // REDUCE: %[[red_value:.*]] = affine.load
     %ld = affine.load %in[%i] : memref<10xf32>
+    // REDUCE-NOT: addf
     %add = addf %red_iter, %ld : f32
+    // REDUCE: affine.yield %[[red_value]]
     affine.yield %add : f32
   }
+  // REDUCE: addf %[[init]], %[[reduced]]
   return
 }
 
-// CHECK-LABEL: @unsupported_nested_iter_args
-func @unsupported_nested_iter_args(%in: memref<20x10xf32>) {
+// CHECK-LABEL: @nested_iter_args
+// REDUCE-LABEL: @nested_iter_args
+func @nested_iter_args(%in: memref<20x10xf32>) {
   %cst = constant 0.000000e+00 : f32
   // CHECK: affine.parallel
   affine.for %i = 0 to 20 {
-    // CHECK: affine.for
+    // CHECK-NOT: affine.parallel
+    // REDUCE: affine.parallel
+    // REDUCE: reduce ("addf")
     %final_red = affine.for %j = 0 to 10 iter_args(%red_iter = %cst) -> (f32) {
       %ld = affine.load %in[%i, %j] : memref<20x10xf32>
       %add = addf %red_iter, %ld : f32
@@ -185,3 +196,43 @@ func @unsupported_nested_iter_args(%in: memref<20x10xf32>) {
   }
   return
 }
+
+// REDUCE-LABEL: @strange_butterfly
+func @strange_butterfly() {
+  %cst1 = constant 0.0 : f32
+  %cst2 = constant 1.0 : f32
+  // REDUCE-NOT: affine.parallel
+  affine.for %i = 0 to 10 iter_args(%it1 = %cst1, %it2 = %cst2) -> (f32, f32) {
+    %0 = addf %it1, %it2 : f32
+    affine.yield %0, %0 : f32, f32
+  }
+  return
+}
+
+// An iter arg is used more than once. This is not a simple reduction and
+// should not be parallelized.
+// REDUCE-LABEL: @repeated_use
+func @repeated_use() {
+  %cst1 = constant 0.0 : f32
+  // REDUCE-NOT: affine.parallel
+  affine.for %i = 0 to 10 iter_args(%it1 = %cst1) -> (f32) {
+    %0 = addf %it1, %it1 : f32
+    affine.yield %0 : f32
+  }
+  return
+}
+
+// An iter arg is used in the chain of operations defining the value being
+// reduced, this is not a simple reduction and should not be parallelized.
+// REDUCE-LABEL: @use_in_backward_slice
+func @use_in_backward_slice() {
+  %cst1 = constant 0.0 : f32
+  %cst2 = constant 1.0 : f32
+  // REDUCE-NOT: affine.parallel
+  affine.for %i = 0 to 10 iter_args(%it1 = %cst1, %it2 = %cst2) -> (f32, f32) {
+    %0 = "test.some_modification"(%it2) : (f32) -> f32
+    %1 = addf %it1, %0 : f32
+    affine.yield %1, %1 : f32, f32
+  }
+  return
+}


        


More information about the Mlir-commits mailing list