[Mlir-commits] [mlir] 34a164c - [mlir:Async] Submit accidentally omitted changes

Eugene Zhulenev llvmlistbot at llvm.org
Fri Jun 25 12:23:09 PDT 2021


Author: Eugene Zhulenev
Date: 2021-06-25T12:23:02-07:00
New Revision: 34a164c93857f609862d74455d6801fe482ead8a

URL: https://github.com/llvm/llvm-project/commit/34a164c93857f609862d74455d6801fe482ead8a
DIFF: https://github.com/llvm/llvm-project/commit/34a164c93857f609862d74455d6801fe482ead8a.diff

LOG: [mlir:Async] Submit accidentally omitted changes

Accidentally pushed old branches that did not include all the changes discussed in the PRs.

https://reviews.llvm.org/rGd43b23608ad664f02f56e965ca78916bde220950
https://reviews.llvm.org/rG86ad0af87054c3cccd68d32e103a6f1f6c6194c7

Differential Revision: https://reviews.llvm.org/D104943

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Async/Passes.h
    mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
    mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
    mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
    mlir/test/Dialect/Async/async-parallel-for-seq-dispatch.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h
index 5c5dbe914aed5..5d0a7f66cf774 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.h
+++ b/mlir/include/mlir/Dialect/Async/Passes.h
@@ -19,6 +19,10 @@ namespace mlir {
 
 std::unique_ptr<Pass> createAsyncParallelForPass();
 
+std::unique_ptr<Pass> createAsyncParallelForPass(bool asyncDispatch,
+                                                 int32_t numWorkerThreads,
+                                                 int32_t targetBlockSize);
+
 std::unique_ptr<OperationPass<ModuleOp>> createAsyncToAsyncRuntimePass();
 
 std::unique_ptr<Pass> createAsyncRuntimeRefCountingPass();

diff  --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 0156ede1b9b65..7d921839cf287 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -596,7 +596,7 @@ class RuntimeCreateGroupOpLowering
   matchAndRewrite(RuntimeCreateGroupOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     TypeConverter *converter = getTypeConverter();
-    Type resultType = op->getResultTypes()[0];
+    Type resultType = op.getResult().getType();
 
     rewriter.replaceOpWithNewOp<CallOp>(
         op, kCreateGroup, converter->convertType(resultType), operands);

diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 02541479fb24b..1d545a52f7152 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -90,6 +90,14 @@ namespace {
 struct AsyncParallelForPass
     : public AsyncParallelForBase<AsyncParallelForPass> {
   AsyncParallelForPass() = default;
+
+  AsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads,
+                       int32_t targetBlockSize) {
+    this->asyncDispatch = asyncDispatch;
+    this->numWorkerThreads = numWorkerThreads;
+    this->targetBlockSize = targetBlockSize;
+  }
+
   void runOnOperation() override;
 };
 
@@ -127,7 +135,7 @@ struct ParallelComputeFunction {
 // Converts one-dimensional iteration index in the [0, tripCount) interval
 // into multidimensional iteration coordinate.
 static SmallVector<Value> delinearize(ImplicitLocOpBuilder &b, Value index,
-                                      const SmallVector<Value> &tripCounts) {
+                                      ArrayRef<Value> tripCounts) {
   SmallVector<Value> coords(tripCounts.size());
   assert(!tripCounts.empty() && "tripCounts must be not empty");
 
@@ -184,7 +192,6 @@ createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
   ModuleOp module = op->getParentOfType<ModuleOp>();
-  b.setInsertionPointToStart(&module->getRegion(0).front());
 
   ParallelComputeFunctionType computeFuncType =
       getParallelComputeFunctionType(op, rewriter);
@@ -204,12 +211,13 @@ createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
 
   unsigned offset = 0; // argument offset for arguments decoding
 
-  // Load multiple arguments into values vector.
-  auto getArguments = [&](unsigned num_arguments) -> SmallVector<Value> {
-    SmallVector<Value> values(num_arguments);
-    for (unsigned i = 0; i < num_arguments; ++i)
-      values[i] = block->getArgument(offset++);
-    return values;
+  // Returns `numArguments` arguments starting from `offset` and updates offset
+  // by moving forward to the next argument.
+  auto getArguments = [&](unsigned numArguments) -> ArrayRef<Value> {
+    auto args = block->getArguments();
+    auto slice = args.drop_front(offset).take_front(numArguments);
+    offset += numArguments;
+    return {slice.begin(), slice.end()};
   };
 
   // Block iteration position defined by the block index and size.
@@ -217,11 +225,11 @@ createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
   Value blockSize = block->getArgument(offset++);
 
   // Constants used below.
-  Value c0 = b.create<ConstantOp>(b.getIndexAttr(0));
-  Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));
+  Value c0 = b.create<ConstantIndexOp>(0);
+  Value c1 = b.create<ConstantIndexOp>(1);
 
   // Multi-dimensional parallel iteration space defined by the loop trip counts.
-  SmallVector<Value> tripCounts = getArguments(op.getNumLoops());
+  ArrayRef<Value> tripCounts = getArguments(op.getNumLoops());
 
   // Compute a product of trip counts to get the size of the flattened
   // one-dimensional iteration space.
@@ -229,35 +237,34 @@ createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
   for (unsigned i = 1; i < tripCounts.size(); ++i)
     tripCount = b.create<MulIOp>(tripCount, tripCounts[i]);
 
-  // Parallel operation lower bound, upper bound and step.
-  SmallVector<Value> lowerBound = getArguments(op.getNumLoops());
-  SmallVector<Value> upperBound = getArguments(op.getNumLoops());
-  SmallVector<Value> step = getArguments(op.getNumLoops());
+  // Parallel operation lower bound and step.
+  ArrayRef<Value> lowerBound = getArguments(op.getNumLoops());
+  offset += op.getNumLoops(); // skip upper bound arguments
+  ArrayRef<Value> step = getArguments(op.getNumLoops());
 
   // Remaining arguments are implicit captures of the parallel operation.
-  SmallVector<Value> captures = getArguments(block->getNumArguments() - offset);
+  ArrayRef<Value> captures = getArguments(block->getNumArguments() - offset);
 
   // Find one-dimensional iteration bounds: [blockFirstIndex, blockLastIndex]:
   //   blockFirstIndex = blockIndex * blockSize
   Value blockFirstIndex = b.create<MulIOp>(blockIndex, blockSize);
 
   // The last one-dimensional index in the block defined by the `blockIndex`:
-  //   blockLastIndex = max((blockIndex + 1) * blockSize, tripCount) - 1
-  Value blockEnd0 = b.create<AddIOp>(blockIndex, c1);
-  Value blockEnd1 = b.create<MulIOp>(blockEnd0, blockSize);
-  Value blockEnd2 = b.create<CmpIOp>(CmpIPredicate::sge, blockEnd1, tripCount);
-  Value blockEnd3 = b.create<SelectOp>(blockEnd2, tripCount, blockEnd1);
-  Value blockLastIndex = b.create<SubIOp>(blockEnd3, c1);
+  //   blockLastIndex = min(blockFirstIndex + blockSize, tripCount) - 1
+  Value blockEnd0 = b.create<AddIOp>(blockFirstIndex, blockSize);
+  Value blockEnd1 = b.create<CmpIOp>(CmpIPredicate::sge, blockEnd0, tripCount);
+  Value blockEnd2 = b.create<SelectOp>(blockEnd1, tripCount, blockEnd0);
+  Value blockLastIndex = b.create<SubIOp>(blockEnd2, c1);
 
   // Convert one-dimensional indices to multi-dimensional coordinates.
   auto blockFirstCoord = delinearize(b, blockFirstIndex, tripCounts);
   auto blockLastCoord = delinearize(b, blockLastIndex, tripCounts);
 
-  // Compute compute loops upper bounds from the block last coordinates:
+  // Compute loops upper bounds derived from the block last coordinates:
   //   blockEndCoord[i] = blockLastCoord[i] + 1
   //
   // Block first and last coordinates can be the same along the outer compute
-  // dimension when inner compute dimension containts multple blocks.
+  // dimension when inner compute dimension contains multiple blocks.
   SmallVector<Value> blockEndCoord(op.getNumLoops());
   for (size_t i = 0; i < blockLastCoord.size(); ++i)
     blockEndCoord[i] = b.create<AddIOp>(blockLastCoord[i], c1);
@@ -312,7 +319,7 @@ createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
       isBlockLastCoord[loopIdx] =
           nb.create<CmpIOp>(CmpIPredicate::eq, iv, blockLastCoord[loopIdx]);
 
-      // Check if the previous loop is in its first of last iteration.
+      // Check if the previous loop is in its first or last iteration.
       if (loopIdx > 0) {
         isBlockFirstCoord[loopIdx] = nb.create<AndOp>(
             isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]);
@@ -380,7 +387,6 @@ static FuncOp createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
   ImplicitLocOpBuilder b(loc, rewriter);
 
   ModuleOp module = computeFunc.func->getParentOfType<ModuleOp>();
-  b.setInsertionPointToStart(&module->getRegion(0).front());
 
   ArrayRef<Type> computeFuncInputTypes =
       computeFunc.func.type().cast<FunctionType>().getInputs();
@@ -408,8 +414,8 @@ static FuncOp createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
   b.setInsertionPointToEnd(block);
 
   Type indexTy = b.getIndexType();
-  Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));
-  Value c2 = b.create<ConstantOp>(b.getIndexAttr(2));
+  Value c1 = b.create<ConstantIndexOp>(1);
+  Value c2 = b.create<ConstantIndexOp>(2);
 
   // Get the async group that will track async dispatch completion.
   Value group = block->getArgument(0);
@@ -439,14 +445,14 @@ static FuncOp createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
   }
 
   // Setup the async dispatch loop body: recursively call dispatch function
-  // for second the half of the original range and go to the next iteration.
+  // for the second half of the original range and go to the next iteration.
   {
     b.setInsertionPointToEnd(after);
     Value start = after->getArgument(0);
     Value end = after->getArgument(1);
     Value distance = b.create<SubIOp>(end, start);
     Value halfDistance = b.create<SignedDivIOp>(distance, c2);
-    Value midIndex = b.create<AddIOp>(after->getArgument(0), halfDistance);
+    Value midIndex = b.create<AddIOp>(start, halfDistance);
 
     // Call parallel compute function inside the async.execute region.
     auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
@@ -466,7 +472,7 @@ static FuncOp createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
     auto execute = b.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
                                        executeBodyBuilder);
     b.create<AddToGroupOp>(indexTy, execute.token(), group);
-    b.create<scf::YieldOp>(ValueRange({after->getArgument(0), midIndex}));
+    b.create<scf::YieldOp>(ValueRange({start, midIndex}));
   }
 
   // After dispatching async operations to process the tail of the block range
@@ -498,8 +504,8 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
   FuncOp asyncDispatchFunction =
       createAsyncDispatchFunction(parallelComputeFunction, rewriter);
 
-  Value c0 = b.create<ConstantOp>(b.getIndexAttr(0));
-  Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));
+  Value c0 = b.create<ConstantIndexOp>(0);
+  Value c1 = b.create<ConstantIndexOp>(1);
 
   // Create an async.group to wait on all async tokens from the concurrent
   // execution of multiple parallel compute function. First block will be
@@ -535,8 +541,8 @@ doSequantialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
 
   FuncOp compute = parallelComputeFunction.func;
 
-  Value c0 = b.create<ConstantOp>(b.getIndexAttr(0));
-  Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));
+  Value c0 = b.create<ConstantIndexOp>(0);
+  Value c1 = b.create<ConstantIndexOp>(1);
 
   // Create an async.group to wait on all async tokens from the concurrent
   // execution of multiple parallel compute function. First block will be
@@ -617,19 +623,16 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
   for (size_t i = 1; i < tripCounts.size(); ++i)
     tripCount = b.create<MulIOp>(tripCount, tripCounts[i]);
 
-  auto indexTy = b.getIndexType();
-
   // Do not overload worker threads with too many compute blocks.
-  Value maxComputeBlocks = b.create<ConstantOp>(
-      indexTy, b.getIndexAttr(numWorkerThreads * kMaxOversharding));
+  Value maxComputeBlocks =
+      b.create<ConstantIndexOp>(numWorkerThreads * kMaxOversharding);
 
   // Target block size from the pass parameters.
-  Value targetComputeBlockSize =
-      b.create<ConstantOp>(indexTy, b.getIndexAttr(targetBlockSize));
+  Value targetComputeBlockSize = b.create<ConstantIndexOp>(targetBlockSize);
 
   // Compute parallel block size from the parallel problem size:
   //   blockSize = min(tripCount,
-  //                   max(divup(tripCount, maxComputeBlocks),
+  //                   max(ceil_div(tripCount, maxComputeBlocks),
   //                       targetComputeBlockSize))
   Value bs0 = b.create<SignedCeilDivIOp>(tripCount, maxComputeBlocks);
   Value bs1 = b.create<CmpIOp>(CmpIPredicate::sge, bs0, targetComputeBlockSize);
@@ -653,7 +656,7 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
                          blockCount, tripCounts);
   }
 
-  // Parallel operation was replaces with a block iteration loop.
+  // Parallel operation was replaced with a block iteration loop.
   rewriter.eraseOp(op);
 
   return success();
@@ -673,3 +676,10 @@ void AsyncParallelForPass::runOnOperation() {
 std::unique_ptr<Pass> mlir::createAsyncParallelForPass() {
   return std::make_unique<AsyncParallelForPass>();
 }
+
+std::unique_ptr<Pass>
+mlir::createAsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads,
+                                 int32_t targetBlockSize) {
+  return std::make_unique<AsyncParallelForPass>(asyncDispatch, numWorkerThreads,
+                                                targetBlockSize);
+}

diff  --git a/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir b/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
index 72b2e01045482..df538a4fc7661 100644
--- a/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
@@ -18,18 +18,33 @@ func @loop_1d(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<?xf32>) {
 // CHECK:         memref.store
 
 // CHECK-LABEL: func private @async_dispatch_fn
+// CHECK-SAME:  (
 // CHECK-SAME:    %[[GROUP:arg0]]: !async.group,
 // CHECK-SAME:    %[[BLOCK_START:arg1]]: index
 // CHECK-SAME:    %[[BLOCK_END:arg2]]: index
-
-// CHECK:         scf.while (%[[S:.*]] = %[[BLOCK_START]],
-// CHECK-SAME:               %[[E:.*]] = %[[BLOCK_END]])
+// CHECK-SAME:  )
+// CHECK:         %[[C1:.*]] = constant 1 : index
+// CHECK:         %[[C2:.*]] = constant 2 : index
+// CHECK:         scf.while (%[[S0:.*]] = %[[BLOCK_START]],
+// CHECK-SAME:               %[[E0:.*]] = %[[BLOCK_END]])
+// While loop `before` block decides if we need to dispatch more tasks.
+// CHECK:         {
+// CHECK:           %[[DIFF0:.*]] = subi %[[E0]], %[[S0]]
+// CHECK:           %[[COND:.*]] = cmpi sgt, %[[DIFF0]], %[[C1]]
+// CHECK:           scf.condition(%[[COND]])
+// While loop `after` block splits the range in half and submits async task
+// to process the second half using the call to the same dispatch function.
 // CHECK:         } do {
+// CHECK:         ^bb0(%[[S1:.*]]: index, %[[E1:.*]]: index):
+// CHECK:           %[[DIFF1:.*]] = subi %[[E1]], %[[S1]]
+// CHECK:           %[[HALF:.*]] = divi_signed %[[DIFF1]], %[[C2]]
+// CHECK:           %[[MID:.*]] = addi %[[S1]], %[[HALF]]
 // CHECK:           %[[TOKEN:.*]] = async.execute
 // CHECK:             call @async_dispatch_fn
-// CHECK:             async.add_to_group
+// CHECK:           async.add_to_group
+// CHECK:           scf.yield %[[S1]], %[[MID]]
 // CHECK:         }
-
+// After async dispatch the first block is processed in the caller thread.
 // CHECK:         call @parallel_compute_fn(%[[BLOCK_START]]
 
 // -----

diff  --git a/mlir/test/Dialect/Async/async-parallel-for-seq-dispatch.mlir b/mlir/test/Dialect/Async/async-parallel-for-seq-dispatch.mlir
index 81ed85323e033..8fee953829fd5 100644
--- a/mlir/test/Dialect/Async/async-parallel-for-seq-dispatch.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-seq-dispatch.mlir
@@ -1,6 +1,9 @@
 // RUN: mlir-opt %s -split-input-file -async-parallel-for=async-dispatch=false  \
 // RUN: | FileCheck %s --dump-input=always
 
+// The structure of @parallel_compute_fn is checked in the async dispatch test.
+// Here we only check the structure of the sequential dispatch loop.
+
 // CHECK-LABEL: @loop_1d
 func @loop_1d(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<?xf32>) {
   // CHECK: %[[GROUP:.*]] = async.create_group


        


More information about the Mlir-commits mailing list