[Mlir-commits] [mlir] 49ce40e - [mlir] AsyncParallelFor: align block size to be a multiple of inner loops iterations

Eugene Zhulenev llvmlistbot at llvm.org
Thu Dec 9 06:50:57 PST 2021

Author: Eugene Zhulenev
Date: 2021-12-09T06:50:50-08:00
New Revision: 49ce40e9ab255588bf5093f468b30cbb6b99895b

URL: https://github.com/llvm/llvm-project/commit/49ce40e9ab255588bf5093f468b30cbb6b99895b
DIFF: https://github.com/llvm/llvm-project/commit/49ce40e9ab255588bf5093f468b30cbb6b99895b.diff

LOG: [mlir] AsyncParallelFor: align block size to be a multiple of inner loops iterations

Depends On D115263

Aligning the block size to a multiple of the inner loop iteration counts lets LLVM later unroll and vectorize inner loops with small trip counts inside parallel_compute_fn. Up to 2x speedup in multiple benchmarks.

Reviewed By: bkramer

Differential Revision: https://reviews.llvm.org/D115436




diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index a4cab9d1ae6c4..d05e6bca19d11 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -240,10 +240,9 @@ getParallelComputeFunctionType(scf::ParallelOp op, PatternRewriter &rewriter) {
 // Create a parallel compute fuction from the parallel operation.
-static ParallelComputeFunction
-createParallelComputeFunction(scf::ParallelOp op,
-                              ParallelComputeFunctionBounds bounds,
-                              PatternRewriter &rewriter) {
+static ParallelComputeFunction createParallelComputeFunction(
+    scf::ParallelOp op, ParallelComputeFunctionBounds bounds,
+    unsigned numBlockAlignedInnerLoops, PatternRewriter &rewriter) {
   OpBuilder::InsertionGuard guard(rewriter);
   ImplicitLocOpBuilder b(op.getLoc(), rewriter);
@@ -384,17 +383,26 @@ createParallelComputeFunction(scf::ParallelOp op,
       // Keep building loop nest.
       if (loopIdx < op.getNumLoops() - 1) {
-        // Select nested loop lower/upper bounds depending on our position in
-        // the multi-dimensional iteration space.
-        auto lb = nb.create<SelectOp>(isBlockFirstCoord[loopIdx],
-                                      blockFirstCoord[loopIdx + 1], c0);
+        if (loopIdx + 1 >= op.getNumLoops() - numBlockAlignedInnerLoops) {
+          // For block aligned loops we always iterate starting from 0 up to
+          // the loop trip counts.
+          nb.create<scf::ForOp>(c0, tripCounts[loopIdx + 1], c1, ValueRange(),
+                                workLoopBuilder(loopIdx + 1));
+        } else {
+          // Select nested loop lower/upper bounds depending on our position in
+          // the multi-dimensional iteration space.
+          auto lb = nb.create<SelectOp>(isBlockFirstCoord[loopIdx],
+                                        blockFirstCoord[loopIdx + 1], c0);
+          auto ub = nb.create<SelectOp>(isBlockLastCoord[loopIdx],
+                                        blockEndCoord[loopIdx + 1],
+                                        tripCounts[loopIdx + 1]);
+          nb.create<scf::ForOp>(lb, ub, c1, ValueRange(),
+                                workLoopBuilder(loopIdx + 1));
+        }
-        auto ub = nb.create<SelectOp>(isBlockLastCoord[loopIdx],
-                                      blockEndCoord[loopIdx + 1],
-                                      tripCounts[loopIdx + 1]);
-        nb.create<scf::ForOp>(lb, ub, c1, ValueRange(),
-                              workLoopBuilder(loopIdx + 1));
@@ -731,6 +739,46 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
   auto dispatch = [&](OpBuilder &nestedBuilder, Location loc) {
     ImplicitLocOpBuilder nb(loc, nestedBuilder);
+    // Collect statically known constants defining the loop nest in the parallel
+    // compute function. LLVM can't always push constants across the non-trivial
+    // async dispatch call graph, by providing these values explicitly we can
+    // choose to build more efficient loop nest, and rely on a better constant
+    // folding, loop unrolling and vectorization.
+    ParallelComputeFunctionBounds staticBounds = {
+        integerConstants(tripCounts),
+        integerConstants(op.lowerBound()),
+        integerConstants(op.upperBound()),
+        integerConstants(op.step()),
+    };
+    // Find how many inner iteration dimensions are statically known, and their
+    // product is smaller than the `512`. We align the parallel compute block
+    // size by the product of statically known dimensions, so that we can
+    // guarantee that the inner loops executes from 0 to the loop trip counts
+    // and we can elide dynamic loop boundaries, and give LLVM an opportunity to
+    // unroll the loops. The constant `512` is arbitrary, it should depend on
+    // how many iterations LLVM will typically decide to unroll.
+    static constexpr int64_t maxIterations = 512;
+    // The number of inner loops with statically known number of iterations less
+    // than the `maxIterations` value.
+    int numUnrollableLoops = 0;
+    auto getInt = [](IntegerAttr attr) { return attr ? attr.getInt() : 0; };
+    SmallVector<int64_t> numIterations(op.getNumLoops());
+    numIterations.back() = getInt(staticBounds.tripCounts.back());
+    for (int i = op.getNumLoops() - 2; i >= 0; --i) {
+      int64_t tripCount = getInt(staticBounds.tripCounts[i]);
+      int64_t innerIterations = numIterations[i + 1];
+      numIterations[i] = tripCount * innerIterations;
+      // Update the number of inner loops that we can potentially unroll.
+      if (innerIterations > 0 && innerIterations <= maxIterations)
+        numUnrollableLoops++;
+    }
     // With large number of threads the value of creating many compute blocks
     // is reduced because the problem typically becomes memory bound. For small
     // number of threads it helps with stragglers.
@@ -755,24 +803,28 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
     Value bs0 = b.create<arith::CeilDivSIOp>(tripCount, maxComputeBlocks);
     Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSizeCst);
     Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1);
-    Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);
-    // Collect statically known constants defining the loop nest in the parallel
-    // compute function. LLVM can't always push constants across the non-trivial
-    // async dispatch call graph, by providing these values explicitly we can
-    // choose to build more efficient loop nest, and rely on a better constant
-    // folding, loop unrolling and vectorization.
-    ParallelComputeFunctionBounds staticBounds = {
-        integerConstants(tripCounts),
-        integerConstants(op.lowerBound()),
-        integerConstants(op.upperBound()),
-        integerConstants(op.step()),
-    };
+    // Align the block size to be a multiple of the statically known number
+    // of iterations in the inner loops.
+    if (numUnrollableLoops > 0 && minTaskSize >= maxIterations) {
+      Value numIters = b.create<arith::ConstantIndexOp>(
+          numIterations[op.getNumLoops() - numUnrollableLoops]);
+      Value bs2 = b.create<arith::MulIOp>(
+          b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
+      blockSize = b.create<arith::MinSIOp>(tripCount, bs2);
+    } else {
+      // Reset the number of unrollable loops if we didn't align the block size.
+      numUnrollableLoops = 0;
+    }
+    // Compute the number of parallel compute blocks.
+    Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);
-    // Create a parallel compute function that takes a block id and computes the
-    // parallel operation body for a subset of iteration space.
+    // Create a parallel compute function that takes a block id and computes
+    // the parallel operation body for a subset of iteration space.
     ParallelComputeFunction parallelComputeFunction =
-        createParallelComputeFunction(op, staticBounds, rewriter);
+        createParallelComputeFunction(op, staticBounds, numUnrollableLoops,
+                                      rewriter);
     // Dispatch parallel compute function using async recursive work splitting,
     // or by submitting compute task sequentially from a caller thread.

diff  --git a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
index f95715d9eba06..591f770d9c6fe 100644
--- a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
@@ -65,4 +65,44 @@ func @sink_constant_step(%arg0: memref<?xf32>, %lb: index, %ub: index) {
 // CHECK:        scf.for %[[I:arg[0-9]+]]
 // CHECK:          %[[TMP:.*]] = arith.muli %[[I]], %[[CSTEP]]
 // CHECK:          %[[IDX:.*]] = arith.addi %[[LB]], %[[TMP]]
-// CHECK:          memref.store
\ No newline at end of file
+// CHECK:          memref.store
+// -----
+// Check that for statically known inner loop bound block size is aligned and
+// inner loop uses statically known loop trip counts.
+// CHECK-LABEL: func @sink_constant_step(
+func @sink_constant_step(%arg0: memref<?x10xf32>, %lb: index, %ub: index) {
+  %one = arith.constant 1.0 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  scf.parallel (%i, %j) = (%lb, %c0) to (%ub, %c10) step (%c1, %c1) {
+    memref.store %one, %arg0[%i, %j] : memref<?x10xf32>
+  }
+  return
+// CHECK-LABEL: func private @parallel_compute_fn(
+// CHECK-SAME:   %[[BLOCK_INDEX:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[BLOCK_SIZE:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[TRIP_COUNT0:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[TRIP_COUNT1:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[LB0:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[LB1:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[UB0:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[UB1:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[STEP0:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[STEP1:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[MEMREF:arg[0-9]+]]: memref<?x10xf32>
+// CHECK-SAME: ) {
+// CHECK:        %[[C0:.*]] = arith.constant 0 : index
+// CHECK:        %[[C1:.*]] = arith.constant 1 : index
+// CHECK:        %[[C10:.*]] = arith.constant 10 : index
+// CHECK:        scf.for %[[I:arg[0-9]+]]
+// CHECK-NOT:      select
+// CHECK:          scf.for %[[J:arg[0-9]+]] = %c0 to %c10 step %c1
\ No newline at end of file

diff  --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
index b92dfba6fef6f..444747369dd2a 100644
--- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
+++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
@@ -31,7 +31,7 @@
 // RUN:   mlir-opt %s -async-parallel-for="async-dispatch=false                \
 // RUN:                                    num-workers=20                      \
-// RUN:                                    min-task-size=1"                \
+// RUN:                                    min-task-size=1"                    \
 // RUN:               -async-to-async-runtime                                  \
 // RUN:               -async-runtime-ref-counting                              \
 // RUN:               -async-runtime-ref-counting-opt                          \


More information about the Mlir-commits mailing list