[Mlir-commits] [mlir] 9f151b7 - [mlir] AsyncParallelFor: sink constants into the parallel compute function

Eugene Zhulenev llvmlistbot at llvm.org
Thu Dec 9 06:48:31 PST 2021


Author: Eugene Zhulenev
Date: 2021-12-09T06:48:23-08:00
New Revision: 9f151b784be0d0b27ad29cb24629815701d60481

URL: https://github.com/llvm/llvm-project/commit/9f151b784be0d0b27ad29cb24629815701d60481
DIFF: https://github.com/llvm/llvm-project/commit/9f151b784be0d0b27ad29cb24629815701d60481.diff

LOG: [mlir] AsyncParallelFor: sink constants into the parallel compute function

Because of the complex recursive structure of the async dispatch function, LLVM can't always propagate constants into the parallel_compute_fn, which often prevents optimizations like loop unrolling and vectorization. We help LLVM by pushing known constants into the parallel_compute_fn explicitly.

Reviewed By: bkramer

Differential Revision: https://reviews.llvm.org/D115263

Added: 
    

Modified: 
    mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
    mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 8d80e6753115b..a4cab9d1ae6c4 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/RegionUtils.h"
@@ -120,16 +121,69 @@ struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
 
 struct ParallelComputeFunctionType {
   FunctionType type;
-  llvm::SmallVector<Value> captures;
+  SmallVector<Value> captures;
+};
+
+// Helper struct to parse parallel compute function argument list.
+struct ParallelComputeFunctionArgs {
+  BlockArgument blockIndex();
+  BlockArgument blockSize();
+  ArrayRef<BlockArgument> tripCounts();
+  ArrayRef<BlockArgument> lowerBounds();
+  ArrayRef<BlockArgument> upperBounds();
+  ArrayRef<BlockArgument> steps();
+  ArrayRef<BlockArgument> captures();
+
+  unsigned numLoops;
+  ArrayRef<BlockArgument> args;
+};
+
+struct ParallelComputeFunctionBounds {
+  SmallVector<IntegerAttr> tripCounts;
+  SmallVector<IntegerAttr> lowerBounds;
+  SmallVector<IntegerAttr> upperBounds;
+  SmallVector<IntegerAttr> steps;
 };
 
 struct ParallelComputeFunction {
+  unsigned numLoops;
   FuncOp func;
   llvm::SmallVector<Value> captures;
 };
 
 } // namespace
 
+BlockArgument ParallelComputeFunctionArgs::blockIndex() { return args[0]; }
+BlockArgument ParallelComputeFunctionArgs::blockSize() { return args[1]; }
+
+ArrayRef<BlockArgument> ParallelComputeFunctionArgs::tripCounts() {
+  return args.drop_front(2).take_front(numLoops);
+}
+
+ArrayRef<BlockArgument> ParallelComputeFunctionArgs::lowerBounds() {
+  return args.drop_front(2 + 1 * numLoops).take_front(numLoops);
+}
+
+ArrayRef<BlockArgument> ParallelComputeFunctionArgs::upperBounds() {
+  return args.drop_front(2 + 2 * numLoops).take_front(numLoops);
+}
+
+ArrayRef<BlockArgument> ParallelComputeFunctionArgs::steps() {
+  return args.drop_front(2 + 3 * numLoops).take_front(numLoops);
+}
+
+ArrayRef<BlockArgument> ParallelComputeFunctionArgs::captures() {
+  return args.drop_front(2 + 4 * numLoops);
+}
+
+template <typename ValueRange>
+static SmallVector<IntegerAttr> integerConstants(ValueRange values) {
+  SmallVector<IntegerAttr> attrs(values.size());
+  for (unsigned i = 0; i < values.size(); ++i)
+    matchPattern(values[i], m_Constant(&attrs[i]));
+  return attrs;
+}
+
 // Converts one-dimensional iteration index in the [0, tripCount) interval
 // into multidimensional iteration coordinate.
 static SmallVector<Value> delinearize(ImplicitLocOpBuilder &b, Value index,
@@ -154,7 +208,7 @@ getParallelComputeFunctionType(scf::ParallelOp op, PatternRewriter &rewriter) {
   llvm::SetVector<Value> captures;
   getUsedValuesDefinedAbove(op.region(), op.region(), captures);
 
-  llvm::SmallVector<Type> inputs;
+  SmallVector<Type> inputs;
   inputs.reserve(2 + 4 * op.getNumLoops() + captures.size());
 
   Type indexTy = rewriter.getIndexType();
@@ -167,7 +221,9 @@ getParallelComputeFunctionType(scf::ParallelOp op, PatternRewriter &rewriter) {
   for (unsigned i = 0; i < op.getNumLoops(); ++i)
     inputs.push_back(indexTy); // loop tripCount
 
-  // Parallel operation lower bound, upper bound and step.
+  // Parallel operation lower bound, upper bound and step. Lower bounds, upper
+  // bounds and steps are passed as contiguous argument groups:
+  //   call @compute(%lb0, %lb1, ..., %ub0, %ub1, ..., %step0, %step1, ...)
   for (unsigned i = 0; i < op.getNumLoops(); ++i) {
     inputs.push_back(indexTy); // lower bound
     inputs.push_back(indexTy); // upper bound
@@ -185,16 +241,14 @@ getParallelComputeFunctionType(scf::ParallelOp op, PatternRewriter &rewriter) {
 
 // Create a parallel compute fuction from the parallel operation.
 static ParallelComputeFunction
-createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
+createParallelComputeFunction(scf::ParallelOp op,
+                              ParallelComputeFunctionBounds bounds,
+                              PatternRewriter &rewriter) {
   OpBuilder::InsertionGuard guard(rewriter);
   ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
   ModuleOp module = op->getParentOfType<ModuleOp>();
 
-  // Make sure that all constants will be inside the parallel operation body to
-  // reduce the number of parallel compute function arguments.
-  cloneConstantsIntoTheRegion(op.getLoopBody(), rewriter);
-
   ParallelComputeFunctionType computeFuncType =
       getParallelComputeFunctionType(op, rewriter);
 
@@ -211,27 +265,35 @@ createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
   Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs());
   b.setInsertionPointToEnd(block);
 
-  unsigned offset = 0; // argument offset for arguments decoding
-
-  // Returns `numArguments` arguments starting from `offset` and updates offset
-  // by moving forward to the next argument.
-  auto getArguments = [&](unsigned numArguments) -> ArrayRef<Value> {
-    auto args = block->getArguments();
-    auto slice = args.drop_front(offset).take_front(numArguments);
-    offset += numArguments;
-    return {slice.begin(), slice.end()};
-  };
+  ParallelComputeFunctionArgs args = {op.getNumLoops(), func.getArguments()};
 
   // Block iteration position defined by the block index and size.
-  Value blockIndex = block->getArgument(offset++);
-  Value blockSize = block->getArgument(offset++);
+  BlockArgument blockIndex = args.blockIndex();
+  BlockArgument blockSize = args.blockSize();
 
   // Constants used below.
   Value c0 = b.create<arith::ConstantIndexOp>(0);
   Value c1 = b.create<arith::ConstantIndexOp>(1);
 
+  // Materialize known constants as constant operations in the function body.
+  auto values = [&](ArrayRef<BlockArgument> args, ArrayRef<IntegerAttr> attrs) {
+    return llvm::to_vector(
+        llvm::map_range(llvm::zip(args, attrs), [&](auto tuple) -> Value {
+          if (IntegerAttr attr = std::get<1>(tuple))
+            return b.create<ConstantOp>(attr);
+          return std::get<0>(tuple);
+        }));
+  };
+
   // Multi-dimensional parallel iteration space defined by the loop trip counts.
-  ArrayRef<Value> tripCounts = getArguments(op.getNumLoops());
+  auto tripCounts = values(args.tripCounts(), bounds.tripCounts);
+
+  // Parallel operation lower bound and step.
+  auto lowerBounds = values(args.lowerBounds(), bounds.lowerBounds);
+  auto steps = values(args.steps(), bounds.steps);
+
+  // Remaining arguments are implicit captures of the parallel operation.
+  ArrayRef<BlockArgument> captures = args.captures();
 
   // Compute a product of trip counts to get the size of the flattened
   // one-dimensional iteration space.
@@ -239,14 +301,6 @@ createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
   for (unsigned i = 1; i < tripCounts.size(); ++i)
     tripCount = b.create<arith::MulIOp>(tripCount, tripCounts[i]);
 
-  // Parallel operation lower bound and step.
-  ArrayRef<Value> lowerBound = getArguments(op.getNumLoops());
-  offset += op.getNumLoops(); // skip upper bound arguments
-  ArrayRef<Value> step = getArguments(op.getNumLoops());
-
-  // Remaining arguments are implicit captures of the parallel operation.
-  ArrayRef<Value> captures = getArguments(block->getNumArguments() - offset);
-
   // Find one-dimensional iteration bounds: [blockFirstIndex, blockLastIndex]:
   //   blockFirstIndex = blockIndex * blockSize
   Value blockFirstIndex = b.create<arith::MulIOp>(blockIndex, blockSize);
@@ -312,7 +366,7 @@ createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
 
       // Compute induction variable for `loopIdx`.
       computeBlockInductionVars[loopIdx] = nb.create<arith::AddIOp>(
-          lowerBound[loopIdx], nb.create<arith::MulIOp>(iv, step[loopIdx]));
+          lowerBounds[loopIdx], nb.create<arith::MulIOp>(iv, steps[loopIdx]));
 
       // Check if we are inside first or last iteration of the loop.
       isBlockFirstCoord[loopIdx] = nb.create<arith::CmpIOp>(
@@ -359,7 +413,7 @@ createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
                        workLoopBuilder(0));
   b.create<ReturnOp>(ValueRange());
 
-  return {func, std::move(computeFuncType.captures)};
+  return {op.getNumLoops(), func, std::move(computeFuncType.captures)};
 }
 
 // Creates recursive async dispatch function for the given parallel compute
@@ -640,6 +694,10 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
 
   ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
+  // Make sure that all constants will be inside the parallel operation body to
+  // reduce the number of parallel compute function arguments.
+  cloneConstantsIntoTheRegion(op.getLoopBody(), rewriter);
+
   // Compute trip count for each loop induction variable:
   //   tripCount = ceil_div(upperBound - lowerBound, step);
   SmallVector<Value> tripCounts(op.getNumLoops());
@@ -647,8 +705,8 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
     auto lb = op.lowerBound()[i];
     auto ub = op.upperBound()[i];
     auto step = op.step()[i];
-    auto range = b.create<arith::SubIOp>(ub, lb);
-    tripCounts[i] = b.create<arith::CeilDivSIOp>(range, step);
+    auto range = b.createOrFold<arith::SubIOp>(ub, lb);
+    tripCounts[i] = b.createOrFold<arith::CeilDivSIOp>(range, step);
   }
 
   // Compute a product of trip counts to get the 1-dimensional iteration space
@@ -699,10 +757,22 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
     Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1);
     Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);
 
+    // Collect statically known constants defining the loop nest in the parallel
+    // compute function. LLVM can't always push constants across the non-trivial
+    // async dispatch call graph; by providing these values explicitly we can
+    // build a more efficient loop nest and rely on better constant folding,
+    // loop unrolling and vectorization.
+    ParallelComputeFunctionBounds staticBounds = {
+        integerConstants(tripCounts),
+        integerConstants(op.lowerBound()),
+        integerConstants(op.upperBound()),
+        integerConstants(op.step()),
+    };
+
     // Create a parallel compute function that takes a block id and computes the
     // parallel operation body for a subset of iteration space.
     ParallelComputeFunction parallelComputeFunction =
-        createParallelComputeFunction(op, rewriter);
+        createParallelComputeFunction(op, staticBounds, rewriter);
 
     // Dispatch parallel compute function using async recursive work splitting,
     // or by submitting compute task sequentially from a caller thread.

diff  --git a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
index 5f0caf9e5ecf4..f95715d9eba06 100644
--- a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
@@ -1,8 +1,8 @@
-// RUN: mlir-opt %s                                                            \
+// RUN: mlir-opt %s -split-input-file                                          \
 // RUN:    -async-parallel-for=async-dispatch=true                             \
 // RUN: | FileCheck %s
 
-// RUN: mlir-opt %s                                                            \
+// RUN: mlir-opt %s -split-input-file                                          \
 // RUN:    -async-parallel-for=async-dispatch=false                            \
 // RUN:    -canonicalize -inline -symbol-dce                                   \
 // RUN: | FileCheck %s
@@ -34,3 +34,35 @@ func @clone_constant(%arg0: memref<?xf32>, %lb: index, %ub: index, %st: index) {
 // CHECK:        %[[CST:.*]] = arith.constant 1.0{{.*}} : f32
 // CHECK:        scf.for
 // CHECK:          memref.store %[[CST]], %[[MEMREF]]
+
+// -----
+
+// Check that a constant loop step is sunk into the parallel compute function.
+
+// CHECK-LABEL: func @sink_constant_step(
+func @sink_constant_step(%arg0: memref<?xf32>, %lb: index, %ub: index) {
+  %one = arith.constant 1.0 : f32
+  %st = arith.constant 123 : index
+
+  scf.parallel (%i) = (%lb) to (%ub) step (%st) {
+    memref.store %one, %arg0[%i] : memref<?xf32>
+  }
+
+  return
+}
+
+// CHECK-LABEL: func private @parallel_compute_fn(
+// CHECK-SAME:   %[[BLOCK_INDEX:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[BLOCK_SIZE:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[TRIP_COUNT:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[LB:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[UB:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[STEP:arg[0-9]+]]: index,
+// CHECK-SAME:   %[[MEMREF:arg[0-9]+]]: memref<?xf32>
+// CHECK-SAME: ) {
+// CHECK:        %[[CSTEP:.*]] = arith.constant 123 : index
+// CHECK-NOT:    %[[STEP]]
+// CHECK:        scf.for %[[I:arg[0-9]+]]
+// CHECK:          %[[TMP:.*]] = arith.muli %[[I]], %[[CSTEP]]
+// CHECK:          %[[IDX:.*]] = arith.addi %[[LB]], %[[TMP]]
+// CHECK:          memref.store
\ No newline at end of file


        


More information about the Mlir-commits mailing list