[Mlir-commits] [mlir] 9f151b7 - [mlir] AsyncParallelFor: sink constants into the parallel compute function
Eugene Zhulenev
llvmlistbot at llvm.org
Thu Dec 9 06:48:31 PST 2021
Author: Eugene Zhulenev
Date: 2021-12-09T06:48:23-08:00
New Revision: 9f151b784be0d0b27ad29cb24629815701d60481
URL: https://github.com/llvm/llvm-project/commit/9f151b784be0d0b27ad29cb24629815701d60481
DIFF: https://github.com/llvm/llvm-project/commit/9f151b784be0d0b27ad29cb24629815701d60481.diff
LOG: [mlir] AsyncParallelFor: sink constants into the parallel compute function
Because of the complex recursive structure of the async dispatch function, LLVM can't always propagate constants into the parallel_compute_fn, which often prevents optimizations such as loop unrolling and vectorization. We help LLVM by explicitly sinking known constants into the parallel_compute_fn.
Reviewed By: bkramer
Differential Revision: https://reviews.llvm.org/D115263
Added:
Modified:
mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 8d80e6753115b..a4cab9d1ae6c4 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
@@ -120,16 +121,69 @@ struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
struct ParallelComputeFunctionType {
FunctionType type;
- llvm::SmallVector<Value> captures;
+ SmallVector<Value> captures;
+};
+
+// Helper struct to parse parallel compute function argument list.
+struct ParallelComputeFunctionArgs {
+ BlockArgument blockIndex();
+ BlockArgument blockSize();
+ ArrayRef<BlockArgument> tripCounts();
+ ArrayRef<BlockArgument> lowerBounds();
+ ArrayRef<BlockArgument> upperBounds();
+ ArrayRef<BlockArgument> steps();
+ ArrayRef<BlockArgument> captures();
+
+ unsigned numLoops;
+ ArrayRef<BlockArgument> args;
+};
+
+struct ParallelComputeFunctionBounds {
+ SmallVector<IntegerAttr> tripCounts;
+ SmallVector<IntegerAttr> lowerBounds;
+ SmallVector<IntegerAttr> upperBounds;
+ SmallVector<IntegerAttr> steps;
};
struct ParallelComputeFunction {
+ unsigned numLoops;
FuncOp func;
llvm::SmallVector<Value> captures;
};
} // namespace
+BlockArgument ParallelComputeFunctionArgs::blockIndex() { return args[0]; }
+BlockArgument ParallelComputeFunctionArgs::blockSize() { return args[1]; }
+
+ArrayRef<BlockArgument> ParallelComputeFunctionArgs::tripCounts() {
+ return args.drop_front(2).take_front(numLoops);
+}
+
+ArrayRef<BlockArgument> ParallelComputeFunctionArgs::lowerBounds() {
+ return args.drop_front(2 + 1 * numLoops).take_front(numLoops);
+}
+
+ArrayRef<BlockArgument> ParallelComputeFunctionArgs::upperBounds() {
+ return args.drop_front(2 + 2 * numLoops).take_front(numLoops);
+}
+
+ArrayRef<BlockArgument> ParallelComputeFunctionArgs::steps() {
+ return args.drop_front(2 + 3 * numLoops).take_front(numLoops);
+}
+
+ArrayRef<BlockArgument> ParallelComputeFunctionArgs::captures() {
+ return args.drop_front(2 + 4 * numLoops);
+}
+
+template <typename ValueRange>
+static SmallVector<IntegerAttr> integerConstants(ValueRange values) {
+ SmallVector<IntegerAttr> attrs(values.size());
+ for (unsigned i = 0; i < values.size(); ++i)
+ matchPattern(values[i], m_Constant(&attrs[i]));
+ return attrs;
+}
+
// Converts one-dimensional iteration index in the [0, tripCount) interval
// into multidimensional iteration coordinate.
static SmallVector<Value> delinearize(ImplicitLocOpBuilder &b, Value index,
@@ -154,7 +208,7 @@ getParallelComputeFunctionType(scf::ParallelOp op, PatternRewriter &rewriter) {
llvm::SetVector<Value> captures;
getUsedValuesDefinedAbove(op.region(), op.region(), captures);
- llvm::SmallVector<Type> inputs;
+ SmallVector<Type> inputs;
inputs.reserve(2 + 4 * op.getNumLoops() + captures.size());
Type indexTy = rewriter.getIndexType();
@@ -167,7 +221,9 @@ getParallelComputeFunctionType(scf::ParallelOp op, PatternRewriter &rewriter) {
for (unsigned i = 0; i < op.getNumLoops(); ++i)
inputs.push_back(indexTy); // loop tripCount
- // Parallel operation lower bound, upper bound and step.
+ // Parallel operation lower bound, upper bound and step. Lower bound, upper
+ // bound and step passed as contiguous arguments:
+ // call @compute(%lb0, %lb1, ..., %ub0, %ub1, ..., %step0, %step1, ...)
for (unsigned i = 0; i < op.getNumLoops(); ++i) {
inputs.push_back(indexTy); // lower bound
inputs.push_back(indexTy); // upper bound
@@ -185,16 +241,14 @@ getParallelComputeFunctionType(scf::ParallelOp op, PatternRewriter &rewriter) {
// Create a parallel compute fuction from the parallel operation.
static ParallelComputeFunction
-createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
+createParallelComputeFunction(scf::ParallelOp op,
+ ParallelComputeFunctionBounds bounds,
+ PatternRewriter &rewriter) {
OpBuilder::InsertionGuard guard(rewriter);
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
ModuleOp module = op->getParentOfType<ModuleOp>();
- // Make sure that all constants will be inside the parallel operation body to
- // reduce the number of parallel compute function arguments.
- cloneConstantsIntoTheRegion(op.getLoopBody(), rewriter);
-
ParallelComputeFunctionType computeFuncType =
getParallelComputeFunctionType(op, rewriter);
@@ -211,27 +265,35 @@ createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs());
b.setInsertionPointToEnd(block);
- unsigned offset = 0; // argument offset for arguments decoding
-
- // Returns `numArguments` arguments starting from `offset` and updates offset
- // by moving forward to the next argument.
- auto getArguments = [&](unsigned numArguments) -> ArrayRef<Value> {
- auto args = block->getArguments();
- auto slice = args.drop_front(offset).take_front(numArguments);
- offset += numArguments;
- return {slice.begin(), slice.end()};
- };
+ ParallelComputeFunctionArgs args = {op.getNumLoops(), func.getArguments()};
// Block iteration position defined by the block index and size.
- Value blockIndex = block->getArgument(offset++);
- Value blockSize = block->getArgument(offset++);
+ BlockArgument blockIndex = args.blockIndex();
+ BlockArgument blockSize = args.blockSize();
// Constants used below.
Value c0 = b.create<arith::ConstantIndexOp>(0);
Value c1 = b.create<arith::ConstantIndexOp>(1);
+ // Materialize known constants as constant operation in the function body.
+ auto values = [&](ArrayRef<BlockArgument> args, ArrayRef<IntegerAttr> attrs) {
+ return llvm::to_vector(
+ llvm::map_range(llvm::zip(args, attrs), [&](auto tuple) -> Value {
+ if (IntegerAttr attr = std::get<1>(tuple))
+ return b.create<ConstantOp>(attr);
+ return std::get<0>(tuple);
+ }));
+ };
+
// Multi-dimensional parallel iteration space defined by the loop trip counts.
- ArrayRef<Value> tripCounts = getArguments(op.getNumLoops());
+ auto tripCounts = values(args.tripCounts(), bounds.tripCounts);
+
+ // Parallel operation lower bound and step.
+ auto lowerBounds = values(args.lowerBounds(), bounds.lowerBounds);
+ auto steps = values(args.steps(), bounds.steps);
+
+ // Remaining arguments are implicit captures of the parallel operation.
+ ArrayRef<BlockArgument> captures = args.captures();
// Compute a product of trip counts to get the size of the flattened
// one-dimensional iteration space.
@@ -239,14 +301,6 @@ createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
for (unsigned i = 1; i < tripCounts.size(); ++i)
tripCount = b.create<arith::MulIOp>(tripCount, tripCounts[i]);
- // Parallel operation lower bound and step.
- ArrayRef<Value> lowerBound = getArguments(op.getNumLoops());
- offset += op.getNumLoops(); // skip upper bound arguments
- ArrayRef<Value> step = getArguments(op.getNumLoops());
-
- // Remaining arguments are implicit captures of the parallel operation.
- ArrayRef<Value> captures = getArguments(block->getNumArguments() - offset);
-
// Find one-dimensional iteration bounds: [blockFirstIndex, blockLastIndex]:
// blockFirstIndex = blockIndex * blockSize
Value blockFirstIndex = b.create<arith::MulIOp>(blockIndex, blockSize);
@@ -312,7 +366,7 @@ createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
// Compute induction variable for `loopIdx`.
computeBlockInductionVars[loopIdx] = nb.create<arith::AddIOp>(
- lowerBound[loopIdx], nb.create<arith::MulIOp>(iv, step[loopIdx]));
+ lowerBounds[loopIdx], nb.create<arith::MulIOp>(iv, steps[loopIdx]));
// Check if we are inside first or last iteration of the loop.
isBlockFirstCoord[loopIdx] = nb.create<arith::CmpIOp>(
@@ -359,7 +413,7 @@ createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
workLoopBuilder(0));
b.create<ReturnOp>(ValueRange());
- return {func, std::move(computeFuncType.captures)};
+ return {op.getNumLoops(), func, std::move(computeFuncType.captures)};
}
// Creates recursive async dispatch function for the given parallel compute
@@ -640,6 +694,10 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ // Make sure that all constants will be inside the parallel operation body to
+ // reduce the number of parallel compute function arguments.
+ cloneConstantsIntoTheRegion(op.getLoopBody(), rewriter);
+
// Compute trip count for each loop induction variable:
// tripCount = ceil_div(upperBound - lowerBound, step);
SmallVector<Value> tripCounts(op.getNumLoops());
@@ -647,8 +705,8 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
auto lb = op.lowerBound()[i];
auto ub = op.upperBound()[i];
auto step = op.step()[i];
- auto range = b.create<arith::SubIOp>(ub, lb);
- tripCounts[i] = b.create<arith::CeilDivSIOp>(range, step);
+ auto range = b.createOrFold<arith::SubIOp>(ub, lb);
+ tripCounts[i] = b.createOrFold<arith::CeilDivSIOp>(range, step);
}
// Compute a product of trip counts to get the 1-dimensional iteration space
@@ -699,10 +757,22 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1);
Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);
+ // Collect statically known constants defining the loop nest in the parallel
+ // compute function. LLVM can't always push constants across the non-trivial
+ // async dispatch call graph, by providing these values explicitly we can
+ // choose to build more efficient loop nest, and rely on a better constant
+ // folding, loop unrolling and vectorization.
+ ParallelComputeFunctionBounds staticBounds = {
+ integerConstants(tripCounts),
+ integerConstants(op.lowerBound()),
+ integerConstants(op.upperBound()),
+ integerConstants(op.step()),
+ };
+
// Create a parallel compute function that takes a block id and computes the
// parallel operation body for a subset of iteration space.
ParallelComputeFunction parallelComputeFunction =
- createParallelComputeFunction(op, rewriter);
+ createParallelComputeFunction(op, staticBounds, rewriter);
// Dispatch parallel compute function using async recursive work splitting,
// or by submitting compute task sequentially from a caller thread.
diff --git a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
index 5f0caf9e5ecf4..f95715d9eba06 100644
--- a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
@@ -1,8 +1,8 @@
-// RUN: mlir-opt %s \
+// RUN: mlir-opt %s -split-input-file \
// RUN: -async-parallel-for=async-dispatch=true \
// RUN: | FileCheck %s
-// RUN: mlir-opt %s \
+// RUN: mlir-opt %s -split-input-file \
// RUN: -async-parallel-for=async-dispatch=false \
// RUN: -canonicalize -inline -symbol-dce \
// RUN: | FileCheck %s
@@ -34,3 +34,35 @@ func @clone_constant(%arg0: memref<?xf32>, %lb: index, %ub: index, %st: index) {
// CHECK: %[[CST:.*]] = arith.constant 1.0{{.*}} : f32
// CHECK: scf.for
// CHECK: memref.store %[[CST]], %[[MEMREF]]
+
+// -----
+
+// Check that a constant loop step is sunk into the parallel compute function.
+
+// CHECK-LABEL: func @sink_constant_step(
+func @sink_constant_step(%arg0: memref<?xf32>, %lb: index, %ub: index) {
+ %one = arith.constant 1.0 : f32
+ %st = arith.constant 123 : index
+
+ scf.parallel (%i) = (%lb) to (%ub) step (%st) {
+ memref.store %one, %arg0[%i] : memref<?xf32>
+ }
+
+ return
+}
+
+// CHECK-LABEL: func private @parallel_compute_fn(
+// CHECK-SAME: %[[BLOCK_INDEX:arg[0-9]+]]: index,
+// CHECK-SAME: %[[BLOCK_SIZE:arg[0-9]+]]: index,
+// CHECK-SAME: %[[TRIP_COUNT:arg[0-9]+]]: index,
+// CHECK-SAME: %[[LB:arg[0-9]+]]: index,
+// CHECK-SAME: %[[UB:arg[0-9]+]]: index,
+// CHECK-SAME: %[[STEP:arg[0-9]+]]: index,
+// CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref<?xf32>
+// CHECK-SAME: ) {
+// CHECK: %[[CSTEP:.*]] = arith.constant 123 : index
+// CHECK-NOT: %[[STEP]]
+// CHECK: scf.for %[[I:arg[0-9]+]]
+// CHECK: %[[TMP:.*]] = arith.muli %[[I]], %[[CSTEP]]
+// CHECK: %[[IDX:.*]] = arith.addi %[[LB]], %[[TMP]]
+// CHECK: memref.store
\ No newline at end of file
More information about the Mlir-commits
mailing list