[Mlir-commits] [mlir] [mlir][scf] Add parallelLoopUnrollByFactors() (PR #163806)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 23 04:04:49 PDT 2025


https://github.com/fabrizio-indirli updated https://github.com/llvm/llvm-project/pull/163806

>From 8ac30d809b0adf6fe959d303bd37fd2cbaa78fe4 Mon Sep 17 00:00:00 2001
From: Fabrizio Indirli <Fabrizio.Indirli at arm.com>
Date: Thu, 16 Oct 2025 16:22:10 +0100
Subject: [PATCH] [mlir][scf] Add parallelLoopUnrollByFactors()

- In the SCF Utils, add the parallelLoopUnrollByFactors() function
  to unroll scf::ParallelOp loops according to the specified unroll factors.
- Add a test pass "TestParallelLoopUnrolling" and the related LIT test.
- Expose the mlir::parallelLoopUnrollByFactors(), mlir::generateUnrolledLoop(),
  and mlir::scf::computeUbMinusLb() functions in the
  mlir/Dialect/SCF/Utils/Utils.h header to make them available to other passes.
- In mlir::generateUnrolledLoop(), also add an optional
  `IRMapping *clonedToSrcOpsMap` argument that maps each newly cloned operation
  to its original source operation.
  In the function body, make the default no-op `annotateFn` lambda
  `static const` to silence potential warnings about dangling references when
  a function_ref is bound to an object with automatic storage duration.

Signed-off-by: Fabrizio Indirli <Fabrizio.Indirli at arm.com>
Change-Id: I13ffb7b622f001a1ec12989e88bbba643d5dc6b5
---
 mlir/include/mlir/Dialect/SCF/Utils/Utils.h   |  34 ++++
 mlir/lib/Dialect/SCF/IR/SCF.cpp               |  19 +--
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          | 157 ++++++++++++++++--
 .../Dialect/SCF/parallel-loop-unroll.mlir     | 148 +++++++++++++++++
 mlir/test/lib/Dialect/SCF/CMakeLists.txt      |   1 +
 .../Dialect/SCF/TestParallelLoopUnrolling.cpp |  86 ++++++++++
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 +
 7 files changed, 412 insertions(+), 35 deletions(-)
 create mode 100644 mlir/test/Dialect/SCF/parallel-loop-unroll.mlir
 create mode 100644 mlir/test/lib/Dialect/SCF/TestParallelLoopUnrolling.cpp

diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index ecd829ed14add..ed0298f1c5214 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -221,6 +221,40 @@ FailureOr<scf::ForallOp> normalizeForallOp(RewriterBase &rewriter,
 /// 4. Each region iter arg and result has exactly one use
 bool isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops);
 
+/// Generate unrolled copies of an scf loop's 'loopBodyBlock', with 'iterArgs'
+/// and 'yieldedValues' as the block arguments and yielded values of the loop.
+/// The content of the loop body is replicated 'unrollFactor' times, calling
+/// 'ivRemapFn' to remap 'iv' for each unrolled body. If specified, annotates
+/// the Ops in each unrolled iteration using annotateFn. If provided,
+/// 'clonedToSrcOpsMap' is populated with the mappings from each cloned op to
+/// its original source op (following any chain of earlier clones in the map).
+void generateUnrolledLoop(
+    Block *loopBodyBlock, Value iv, uint64_t unrollFactor,
+    function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
+    function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
+    ValueRange iterArgs, ValueRange yieldedValues,
+    IRMapping *clonedToSrcOpsMap = nullptr);
+
+/// Unrolls this scf::Parallel loop by the specified unroll factors. Returns the
+/// unrolled loop if the unroll succeeded; otherwise returns failure if the loop
+/// cannot be unrolled either due to restrictions (e.g. non-constant trip
+/// counts) or due to unroll factors that do not evenly divide the trip counts.
+/// If fewer factors than loop dims are given, they apply to the innermost dims.
+/// If specified, annotates the Ops in each unrolled iteration by applying
+/// `annotateFn`. If provided, 'clonedToSrcOpsMap' is populated with the
+/// mappings from the cloned ops to the original ops.
+FailureOr<scf::ParallelOp> parallelLoopUnrollByFactors(
+    scf::ParallelOp op, ArrayRef<uint64_t> unrollFactors,
+    RewriterBase &rewriter,
+    function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr,
+    IRMapping *clonedToSrcOpsMap = nullptr);
+
+namespace scf {
+/// Returns `ub - lb` as a constant if `ub` is `arith.addi lb, cst` carrying the
+/// no-wrap flag selected by `isSigned`, else std::nullopt.
+std::optional<llvm::APSInt> computeUbMinusLb(Value lb, Value ub, bool isSigned);
+} // namespace scf
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index a9da6c2c8320a..c9ec087392047 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/IRMapping.h"
@@ -110,24 +111,6 @@ static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
   return nullptr;
 }
 
-/// Helper function to compute the difference between two values. This is used
-/// by the loop implementations to compute the trip count.
-static std::optional<llvm::APSInt> computeUbMinusLb(Value lb, Value ub,
-                                                    bool isSigned) {
-  llvm::APSInt diff;
-  auto addOp = ub.getDefiningOp<arith::AddIOp>();
-  if (!addOp)
-    return std::nullopt;
-  if ((isSigned && !addOp.hasNoSignedWrap()) ||
-      (!isSigned && !addOp.hasNoUnsignedWrap()))
-    return std::nullopt;
-
-  if (addOp.getLhs() != lb ||
-      !matchPattern(addOp.getRhs(), m_ConstantInt(&diff)))
-    return std::nullopt;
-  return diff;
-}
-
 //===----------------------------------------------------------------------===//
 // ExecuteRegionOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 10eae8906ce31..ccb9784a0445c 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -291,47 +291,61 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
   return arith::DivUIOp::create(builder, loc, sum, divisor);
 }
 
-/// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
-/// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
-/// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
-/// unrolled iteration using annotateFn.
-static void generateUnrolledLoop(
-    Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
+void mlir::generateUnrolledLoop(
+    Block *loopBodyBlock, Value iv, uint64_t unrollFactor,
     function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
-    ValueRange iterArgs, ValueRange yieldedValues) {
+    ValueRange iterArgs, ValueRange yieldedValues,
+    IRMapping *clonedToSrcOpsMap) {
+
+  // check if the op was cloned from another source op, and return it if found
+  // (or the same op if not found)
+  auto findOriginalSrcOp =
+      [](Operation *op, const IRMapping &clonedToSrcOpsMap) -> Operation * {
+    Operation *srcOp = op;
+    // if the source op derives from another op: traverse the chain to find the
+    // original source op
+    while (srcOp && clonedToSrcOpsMap.contains(srcOp))
+      srcOp = clonedToSrcOpsMap.lookup(srcOp);
+    return srcOp;
+  };
+
   // Builder to insert unrolled bodies just before the terminator of the body of
-  // 'forOp'.
+  // the loop.
   auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);
 
-  constexpr auto defaultAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
+  static const auto noopAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
   if (!annotateFn)
-    annotateFn = defaultAnnotateFn;
+    annotateFn = noopAnnotateFn;
 
   // Keep a pointer to the last non-terminator operation in the original block
   // so that we know what to clone (since we are doing this in-place).
   Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);
 
-  // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies).
+  // Unroll the contents of the loop body (append unrollFactor - 1 additional
+  // copies).
   SmallVector<Value, 4> lastYielded(yieldedValues);
 
   for (unsigned i = 1; i < unrollFactor; i++) {
-    IRMapping operandMap;
-
     // Prepare operand map.
+    IRMapping operandMap;
     operandMap.map(iterArgs, lastYielded);
 
     // If the induction variable is used, create a remapping to the value for
     // this unrolled instance.
-    if (!forOpIV.use_empty()) {
-      Value ivUnroll = ivRemapFn(i, forOpIV, builder);
-      operandMap.map(forOpIV, ivUnroll);
+    if (!iv.use_empty()) {
+      Value ivUnroll = ivRemapFn(i, iv, builder);
+      operandMap.map(iv, ivUnroll);
     }
 
     // Clone the original body of 'forOp'.
     for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) {
-      Operation *clonedOp = builder.clone(*it, operandMap);
+      Operation *srcOp = &(*it);
+      Operation *clonedOp = builder.clone(*srcOp, operandMap);
       annotateFn(i, clonedOp, builder);
+      if (clonedToSrcOpsMap)
+        clonedToSrcOpsMap->map(clonedOp,
+                               findOriginalSrcOp(srcOp, *clonedToSrcOpsMap));
     }
 
     // Update yielded values.
@@ -1544,3 +1558,112 @@ bool mlir::isPerfectlyNestedForLoops(
   }
   return true;
 }
+
+std::optional<llvm::APSInt> mlir::scf::computeUbMinusLb(Value lb, Value ub,
+                                                        bool isSigned) {
+  // Only `ub = lb + cst` with the matching no-wrap flag is recognized.
+  auto addOp = ub.getDefiningOp<arith::AddIOp>();
+  if (!addOp || addOp.getLhs() != lb)
+    return std::nullopt;
+  const bool hasNoWrap =
+      isSigned ? addOp.hasNoSignedWrap() : addOp.hasNoUnsignedWrap();
+  if (!hasNoWrap)
+    return std::nullopt;
+  llvm::APSInt diff;
+  if (!matchPattern(addOp.getRhs(), m_ConstantInt(&diff)))
+    return std::nullopt;
+  return diff;
+}
+
+/// Returns constant per-dim trip counts of `loopOp`, or {} if any is unknown.
+static SmallVector<int64_t> getConstLoopTripCounts(LoopLikeOpInterface loopOp) {
+  std::optional<SmallVector<OpFoldResult>> loBnds = loopOp.getLoopLowerBounds();
+  std::optional<SmallVector<OpFoldResult>> upBnds = loopOp.getLoopUpperBounds();
+  std::optional<SmallVector<OpFoldResult>> steps = loopOp.getLoopSteps();
+  if (!loBnds || !upBnds || !steps)
+    return {};
+  SmallVector<int64_t> tripCounts;
+  for (auto [lb, ub, step] : llvm::zip_equal(*loBnds, *upBnds, *steps)) {
+    std::optional<llvm::APInt> numIter = constantTripCount(
+        lb, ub, step, /*isSigned=*/true, scf::computeUbMinusLb);
+    if (!numIter)
+      return {};
+    tripCounts.push_back(numIter->getSExtValue());
+  }
+  return tripCounts;
+}
+
+FailureOr<scf::ParallelOp> mlir::parallelLoopUnrollByFactors(
+    scf::ParallelOp op, ArrayRef<uint64_t> unrollFactors,
+    RewriterBase &rewriter,
+    function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
+    IRMapping *clonedToSrcOpsMap) {
+  const unsigned numLoops = op.getNumLoops();
+  assert(llvm::none_of(unrollFactors, [](uint64_t f) { return f == 0; }) &&
+         "Expected positive unroll factors");
+  assert((!unrollFactors.empty() && (unrollFactors.size() <= numLoops)) &&
+         "Expected non-empty unroll factors of size <= to the number of loops");
+
+  // Bail out if no valid unroll factors were provided.
+  if (llvm::all_of(unrollFactors, [](uint64_t f) { return f == 1; }))
+    return rewriter.notifyMatchFailure(
+        op, "Unrolling not applied if all factors are 1");
+
+  // Return if the loop body is empty.
+  if (llvm::hasSingleElement(op.getBody()->getOperations()))
+    return rewriter.notifyMatchFailure(op, "Cannot unroll an empty loop body");
+
+  // generateUnrolledLoop() chains yielded values copy-to-copy (iter_args
+  // style), whereas scf.reduce operands must be combined: reject reductions.
+  if (!op.getInitVals().empty())
+    return rewriter.notifyMatchFailure(
+        op, "Cannot unroll loops with reductions");
+
+  // If the provided unroll factors do not cover all the loop dims, they are
+  // applied to the inner loop dimensions.
+  const unsigned firstLoopDimIdx = numLoops - unrollFactors.size();
+
+  // Validate every unrolled dim before mutating the IR.
+  const SmallVector<int64_t> tripCounts = getConstLoopTripCounts(op);
+  std::optional<SmallVector<OpFoldResult>> foldSteps = op.getLoopSteps();
+  if (tripCounts.empty() || !foldSteps)
+    return rewriter.notifyMatchFailure(
+        op, "Failed to compute constant trip counts for the loop. Note that "
+            "dynamic loop sizes are not supported.");
+  SmallVector<int64_t> steps(numLoops, 0);
+  for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) {
+    const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx];
+    if (tripCounts[dimIdx] % unrollFactor)
+      return rewriter.notifyMatchFailure(
+          op, "Unroll factors don't divide the iteration space evenly");
+    std::optional<int64_t> step = getConstantIntValue((*foldSteps)[dimIdx]);
+    if (!step)
+      return rewriter.notifyMatchFailure(op, "Expected constant loop steps");
+    steps[dimIdx] = *step;
+  }
+
+  for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) {
+    const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx];
+    const int64_t origStep = steps[dimIdx];
+    const int64_t newStep = origStep * unrollFactor;
+    // Forward the caller's (possibly null) map so provenance spans all dims.
+    generateUnrolledLoop(
+        op.getBody(), op.getInductionVars()[dimIdx], unrollFactor,
+        [&](unsigned i, Value iv, OpBuilder b) -> Value {
+          // iv' = iv + step * i;
+          AffineExpr expr = b.getAffineDimExpr(0) + origStep * i;
+          AffineMap map = AffineMap::get(1, 0, expr);
+          return affine::AffineApplyOp::create(b, iv.getLoc(), map,
+                                               ValueRange{iv});
+        },
+        annotateFn, /*iterArgs=*/ValueRange(), /*yieldedValues=*/ValueRange(),
+        clonedToSrcOpsMap);
+
+    // Update the loop step, notifying the rewriter of the in-place change.
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(op);
+    Value cst = arith::ConstantIndexOp::create(rewriter, op.getLoc(), newStep);
+    rewriter.modifyOpInPlace(op, [&] { op.getStepMutable()[dimIdx].set(cst); });
+  }
+  return op;
+}
diff --git a/mlir/test/Dialect/SCF/parallel-loop-unroll.mlir b/mlir/test/Dialect/SCF/parallel-loop-unroll.mlir
new file mode 100644
index 0000000000000..1dc6ebb01f0c2
--- /dev/null
+++ b/mlir/test/Dialect/SCF/parallel-loop-unroll.mlir
@@ -0,0 +1,148 @@
+// RUN: mlir-opt %s -test-parallel-loop-unrolling='unroll-factors=1,2' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-parallel-loop-unrolling='unroll-factors=1,2 loop-depth=1' -split-input-file | FileCheck %s --check-prefix CHECK-UNROLL-INNER
+// RUN: mlir-opt %s -test-parallel-loop-unrolling='unroll-factors=3,1' -split-input-file | FileCheck %s --check-prefix CHECK-UNROLL-BY-3
+
+func.func @unroll_simple_parallel_loop(%arg0: memref<1x16x12xf32>, %arg1: memref<1x16x12xf32>) {
+  %c12 = arith.constant 12 : index
+  %c16 = arith.constant 16 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  scf.parallel (%arg2, %arg3, %arg4) = (%c0, %c0, %c0) to (%c1, %c16, %c12) step (%c1, %c1, %c1) {
+    %read = memref.load %arg0[%arg2, %arg3, %arg4] : memref<1x16x12xf32>
+    memref.store %read, %arg1[%arg2, %arg3, %arg4] : memref<1x16x12xf32>
+    scf.reduce
+  }
+  return
+}
+
+// CHECK-LABEL:   func @unroll_simple_parallel_loop
+// CHECK-SAME:     ([[ARG0:%.*]]: memref<1x16x12xf32>, [[ARG1:%.*]]: memref<1x16x12xf32>)
+// CHECK:           scf.parallel ([[IV0:%.*]], [[IV1:%.*]], [[IV2:%.*]]) = (%c0, %c0, %c0) to (%c1, %c16, %c12) step (%c1, %c1{{.*}}, %c2)
+// CHECK:             [[LOADED1:%.*]] = memref.load [[ARG0]][[[IV0]], [[IV1]], [[IV2]]] : memref<1x16x12xf32>
+// CHECK:             memref.store [[LOADED1]], [[ARG1]][[[IV0]], [[IV1]], [[IV2]]] : memref<1x16x12xf32>
+// CHECK:             [[UNR_Y:%.*]] = affine.apply {{.*}}([[IV2]])
+// CHECK:             [[LOADED2:%.*]] = memref.load [[ARG0]][[[IV0]], [[IV1]], [[UNR_Y]]] : memref<1x16x12xf32>
+// CHECK:             memref.store [[LOADED2]], [[ARG1]][[[IV0]], [[IV1]], [[UNR_Y]]] : memref<1x16x12xf32>
+
+// -----
+
+func.func @negative_unroll_factors_dont_divide_evenly(%arg0: memref<1x16x12xf32>, %arg1: memref<1x16x12xf32>) {
+  %c12 = arith.constant 12 : index
+  %c16 = arith.constant 16 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  scf.parallel (%arg2, %arg3, %arg4) = (%c0, %c0, %c0) to (%c1, %c16, %c12) step (%c1, %c1, %c1) {
+    %read = memref.load %arg0[%arg2, %arg3, %arg4] : memref<1x16x12xf32>
+    memref.store %read, %arg1[%arg2, %arg3, %arg4] : memref<1x16x12xf32>
+    scf.reduce
+  }
+  return
+}
+
+// CHECK-UNROLL-BY-3-LABEL:   func @negative_unroll_factors_dont_divide_evenly
+// CHECK-UNROLL-BY-3-SAME:     ([[ARG0:%.*]]: memref<1x16x12xf32>, [[ARG1:%.*]]: memref<1x16x12xf32>)
+// CHECK-UNROLL-BY-3:           scf.parallel ([[IV0:%.*]], [[IV1:%.*]], [[IV2:%.*]]) = (%c0, %c0, %c0) to (%c1, %c16, %c12) step (%c1, %c1, %c1)
+// CHECK-UNROLL-BY-3:           [[LOADED:%.*]] = memref.load [[ARG0]][[[IV0]], [[IV1]], [[IV2]]] : memref<1x16x12xf32>
+// CHECK-UNROLL-BY-3:           memref.store [[LOADED]], [[ARG1]][[[IV0]], [[IV1]], [[IV2]]] : memref<1x16x12xf32>
+// CHECK-UNROLL-BY-3-NOT:       affine.apply
+// CHECK-UNROLL-BY-3-NOT:       memref.load
+// CHECK-UNROLL-BY-3-NOT:       memref.store
+
+// -----
+
+func.func @unroll_outer_nested_parallel_loop(%arg0: memref<5x16x12x4x4xf32>, %arg2: memref<5x16x12x4x4xf32>) {
+  %c4 = arith.constant 4 : index
+  %c12 = arith.constant 12 : index
+  %c16 = arith.constant 16 : index
+  %c5 = arith.constant 5 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  scf.parallel (%arg3, %arg4, %arg5) = (%c0, %c0, %c0) to (%c5, %c16, %c12) step (%c1, %c1, %c1) {
+    scf.parallel (%arg6, %arg7) = (%c0, %c0) to (%c4, %c4) step (%c1, %c1) {
+      %0 = affine.apply affine_map<(d0, d1) -> (d0 + (d1 floordiv 4) * 4)>(%arg4, %arg6)
+      %1 = affine.apply affine_map<(d0, d1) -> (d0 + (d1 floordiv 4) * 4)>(%arg5, %arg7)
+      %subv_in = memref.subview %arg0[%arg3, %0, %1, 0, 0] [1, 1, 1, 4, 4] [1, 1, 1, 1, 1] : memref<5x16x12x4x4xf32> to memref<4x4xf32, strided<[4, 1], offset: ?>>
+      %subv_out = memref.subview %arg2[%arg3, %0, %1, 0, 0] [1, 1, 1, 4, 4] [1, 1, 1, 1, 1] : memref<5x16x12x4x4xf32> to memref<4x4xf32, strided<[4, 1], offset: ?>>
+      linalg.erf ins(%subv_in : memref<4x4xf32, strided<[4, 1], offset: ?>>) outs(%subv_out : memref<4x4xf32, strided<[4, 1], offset: ?>>)
+      scf.reduce
+    }
+    scf.reduce
+  }
+  return
+}
+
+// CHECK-UNROLL-BY-3-LABEL:   func @unroll_outer_nested_parallel_loop
+// CHECK-LABEL:   func @unroll_outer_nested_parallel_loop
+// CHECK-SAME:     ([[ARG0:%.*]]: memref<5x16x12x4x4xf32>, [[ARG1:%.*]]: memref<5x16x12x4x4xf32>)
+// CHECK:           scf.parallel ([[OUTV0:%.*]], [[OUTV1:%.*]], [[OUTV2:%.*]]) = (%c0, %c0, %c0) to (%c5, %c16, %c12) step (%c1, %c1{{.*}}, %c2)
+// CHECK:             scf.parallel ([[INV0:%.*]], [[INV1:%.*]]) = (%c0, %c0) to (%c4, %c4) step (%c1, %c1)
+// CHECK:               affine.apply {{.*}}([[OUTV1]], [[INV0]])
+// CHECK:               affine.apply {{.*}}([[OUTV2]], [[INV1]])
+// CHECK:               linalg.erf
+
+// CHECK:             [[UNR_OUTV2:%.*]] = affine.apply {{.*}}([[OUTV2]])
+// CHECK:             scf.parallel ([[INV0B:%.*]], [[INV1B:%.*]]) = (%c0, %c0) to (%c4, %c4) step (%c1, %c1)
+// CHECK:               affine.apply {{.*}}([[OUTV1]], [[INV0B]])
+// CHECK:               affine.apply {{.*}}([[UNR_OUTV2]], [[INV1B]])
+// CHECK:               linalg.erf
+
+// -----
+
+func.func @negative_unroll_dynamic_parallel_loop(%arg0: memref<1x16x12xf32>, %arg1: memref<1x16x12xf32>, %ub3: index) {
+  %c12 = arith.constant 12 : index
+  %c16 = arith.constant 16 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  scf.parallel (%arg2, %arg3, %arg4) = (%c0, %c0, %c0) to (%c1, %c16, %ub3) step (%c1, %c1, %c1) {
+    %read = memref.load %arg0[%arg2, %arg3, %arg4] : memref<1x16x12xf32>
+    memref.store %read, %arg1[%arg2, %arg3, %arg4] : memref<1x16x12xf32>
+    scf.reduce
+  }
+  return
+}
+
+// CHECK-LABEL:   func @negative_unroll_dynamic_parallel_loop
+// CHECK-SAME:     ([[ARG0:%.*]]: memref<1x16x12xf32>, [[ARG1:%.*]]: memref<1x16x12xf32>, [[UB3:%.*]]: index)
+// CHECK:           scf.parallel ([[IV0:%.*]], [[IV1:%.*]], [[IV2:%.*]]) = (%c0, %c0, %c0) to (%c1, %c16, [[UB3]]) step (%c1, %c1, %c1)
+// CHECK:             [[LOADED:%.*]] = memref.load [[ARG0]][[[IV0]], [[IV1]], [[IV2]]] : memref<1x16x12xf32>
+// CHECK:             memref.store [[LOADED]], [[ARG1]][[[IV0]], [[IV1]], [[IV2]]] : memref<1x16x12xf32>
+// CHECK-NOT:         affine.apply
+// CHECK-NOT:         memref.load
+// CHECK-NOT:         memref.store
+
+// -----
+
+func.func @unroll_inner_nested_parallel_loop(%arg0: memref<5x16x12x4x4xf32>, %arg2: memref<5x16x12x4x4xf32>) {
+  %c4 = arith.constant 4 : index
+  %c12 = arith.constant 12 : index
+  %c16 = arith.constant 16 : index
+  %c5 = arith.constant 5 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  scf.parallel (%arg3, %arg4, %arg5) = (%c0, %c0, %c0) to (%c5, %c16, %c12) step (%c1, %c1, %c1) {
+    scf.parallel (%arg6, %arg7) = (%c0, %c0) to (%c4, %c4) step (%c1, %c1) {
+      %0 = affine.apply affine_map<(d0, d1) -> (d0 + (d1 floordiv 4) * 4)>(%arg4, %arg6)
+      %1 = affine.apply affine_map<(d0, d1) -> (d0 + (d1 floordiv 4) * 4)>(%arg5, %arg7)
+      %subv_in = memref.subview %arg0[%arg3, %0, %1, 0, 0] [1, 1, 1, 4, 4] [1, 1, 1, 1, 1] : memref<5x16x12x4x4xf32> to memref<4x4xf32, strided<[4, 1], offset: ?>>
+      %subv_out = memref.subview %arg2[%arg3, %0, %1, 0, 0] [1, 1, 1, 4, 4] [1, 1, 1, 1, 1] : memref<5x16x12x4x4xf32> to memref<4x4xf32, strided<[4, 1], offset: ?>>
+      linalg.erf ins(%subv_in : memref<4x4xf32, strided<[4, 1], offset: ?>>) outs(%subv_out : memref<4x4xf32, strided<[4, 1], offset: ?>>)
+      scf.reduce
+    }
+    scf.reduce
+  }
+  return
+}
+
+// CHECK-LABEL:                func @unroll_inner_nested_parallel_loop
+// CHECK-UNROLL-INNER-LABEL:   func @unroll_inner_nested_parallel_loop
+// CHECK-UNROLL-INNER-SAME:     ([[ARG0:%.*]]: memref<5x16x12x4x4xf32>, [[ARG1:%.*]]: memref<5x16x12x4x4xf32>)
+// CHECK-UNROLL-INNER:          scf.parallel ([[OUTV0:%.*]], [[OUTV1:%.*]], [[OUTV2:%.*]]) = (%c0, %c0, %c0) to (%c5, %c16, %c12) step (%c1, %c1, %c1)
+// CHECK-UNROLL-INNER:            scf.parallel ([[INV0:%.*]], [[INV1:%.*]]) = (%c0, %c0) to (%c4, %c4) step (%c1{{.*}}, %c2)
+// CHECK-UNROLL-INNER:              affine.apply {{.*}}([[OUTV1]], [[INV0]])
+// CHECK-UNROLL-INNER:              affine.apply {{.*}}([[OUTV2]], [[INV1]])
+// CHECK-UNROLL-INNER:              linalg.erf
+
+// CHECK-UNROLL-INNER:              [[UNR_INV1:%.*]] = affine.apply {{.*}}([[INV1]])
+// CHECK-UNROLL-INNER:              affine.apply {{.*}}([[OUTV1]], [[INV0]])
+// CHECK-UNROLL-INNER:              affine.apply {{.*}}([[OUTV2]], [[UNR_INV1]])
+// CHECK-UNROLL-INNER:              linalg.erf
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/SCF/CMakeLists.txt b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
index 791c2e681415a..d2f97e816cc14 100644
--- a/mlir/test/lib/Dialect/SCF/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
@@ -2,6 +2,7 @@
 add_mlir_library(MLIRSCFTestPasses
   TestLoopParametricTiling.cpp
   TestLoopUnrolling.cpp
+  TestParallelLoopUnrolling.cpp
   TestSCFUtils.cpp
   TestSCFWrapInZeroTripCheck.cpp
   TestUpliftWhileToFor.cpp
diff --git a/mlir/test/lib/Dialect/SCF/TestParallelLoopUnrolling.cpp b/mlir/test/lib/Dialect/SCF/TestParallelLoopUnrolling.cpp
new file mode 100644
index 0000000000000..a32626843fa64
--- /dev/null
+++ b/mlir/test/lib/Dialect/SCF/TestParallelLoopUnrolling.cpp
@@ -0,0 +1,86 @@
+//===- TestParallelLoopUnrolling.cpp - Parallel loop unrolling test pass -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a test pass that unrolls scf.parallel loops at a given
+// nesting depth by the specified per-dimension unroll factors.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+using namespace mlir;
+
+namespace {
+
+/// Returns the number of scf.parallel ops enclosing `op` (i.e. its parallel
+/// loop nesting depth; 0 for an outermost loop).
+static unsigned getNestingDepth(Operation *op) {
+  unsigned depth = 0;
+  for (Operation *parent = op->getParentOp(); parent;
+       parent = parent->getParentOp())
+    depth += isa<scf::ParallelOp>(parent) ? 1 : 0;
+  return depth;
+}
+
+struct TestParallelLoopUnrollingPass
+    : public PassWrapper<TestParallelLoopUnrollingPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestParallelLoopUnrollingPass)
+
+  StringRef getArgument() const final { return "test-parallel-loop-unrolling"; }
+  StringRef getDescription() const final {
+    return "Tests parallel loop unrolling transformation";
+  }
+  TestParallelLoopUnrollingPass() = default;
+  TestParallelLoopUnrollingPass(const TestParallelLoopUnrollingPass &) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // Unrolling creates affine.apply (IV remapping) and arith.constant (steps).
+    registry.insert<affine::AffineDialect, arith::ArithDialect>();
+  }
+
+  void runOnOperation() override {
+    SmallVector<scf::ParallelOp, 4> loops;
+    getOperation()->walk([&](scf::ParallelOp parLoop) {
+      if (getNestingDepth(parLoop) == loopDepth && !unrollFactors.empty() &&
+          unrollFactors.size() <= parLoop.getNumLoops())
+        loops.push_back(parLoop);
+    });
+    auto annotateFn = [this](unsigned i, Operation *op, OpBuilder b) {
+      if (annotateLoop)
+        op->setAttr("unrolled_iteration", b.getUI32IntegerAttr(i));
+    };
+    IRRewriter rewriter(getOperation()->getContext());
+    for (scf::ParallelOp loop : loops)
+      (void)parallelLoopUnrollByFactors(loop, unrollFactors, rewriter,
+                                        annotateFn);
+  }
+
+  ListOption<uint64_t> unrollFactors{
+      *this, "unroll-factors",
+      llvm::cl::desc(
+          "Unroll factors for each parallel loop dim. If fewer factors than "
+          "loop dims are provided, they are applied to the inner dims.")};
+  Option<unsigned> loopDepth{*this, "loop-depth", llvm::cl::desc("Loop depth."),
+                             llvm::cl::init(0)};
+  Option<bool> annotateLoop{*this, "annotate",
+                            llvm::cl::desc("Annotate unrolled iterations."),
+                            llvm::cl::init(false)};
+};
+} // namespace
+
+namespace mlir::test {
+/// Registers the -test-parallel-loop-unrolling pass; mlir-opt calls this from
+/// registerTestPasses().
+void registerTestParallelLoopUnrollingPass() {
+  PassRegistration<TestParallelLoopUnrollingPass>();
+}
+} // namespace mlir::test
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index e4620c009af8c..6ccd60fe8505c 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -139,6 +139,7 @@ void registerTestOneShotModuleBufferizePass();
 void registerTestOpaqueLoc();
 void registerTestOpLoweringPasses();
 void registerTestPadFusion();
+void registerTestParallelLoopUnrollingPass();
 void registerTestRecursiveTypesPass();
 void registerTestSCFUpliftWhileToFor();
 void registerTestSCFUtilsPass();
@@ -286,6 +287,7 @@ void registerTestPasses() {
   mlir::test::registerTestOpaqueLoc();
   mlir::test::registerTestOpLoweringPasses();
   mlir::test::registerTestPadFusion();
+  mlir::test::registerTestParallelLoopUnrollingPass();
   mlir::test::registerTestRecursiveTypesPass();
   mlir::test::registerTestSCFUpliftWhileToFor();
   mlir::test::registerTestSCFUtilsPass();



More information about the Mlir-commits mailing list