[Mlir-commits] [mlir] 7bd87a0 - Promote readability by factoring out creation of min/max operation. Remove unnecessary divisions.
Eugene Zhulenev
llvmlistbot at llvm.org
Wed Nov 24 16:17:31 PST 2021
Author: bakhtiyar
Date: 2021-11-24T16:17:23-08:00
New Revision: 7bd87a03fdf1559569db1820abb21b6a479b0934
URL: https://github.com/llvm/llvm-project/commit/7bd87a03fdf1559569db1820abb21b6a479b0934
DIFF: https://github.com/llvm/llvm-project/commit/7bd87a03fdf1559569db1820abb21b6a479b0934.diff
LOG: Promote readability by factoring out creation of min/max operation. Remove unnecessary divisions.
Reviewed By: ezhulenev
Differential Revision: https://reviews.llvm.org/D110680
Added:
Modified:
mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 77b8717145c6f..c9514a9da0851 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -252,12 +252,9 @@ createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
Value blockFirstIndex = b.create<arith::MulIOp>(blockIndex, blockSize);
// The last one-dimensional index in the block defined by the `blockIndex`:
- // blockLastIndex = max(blockFirstIndex + blockSize, tripCount) - 1
Value blockEnd0 = b.create<arith::AddIOp>(blockFirstIndex, blockSize);
- Value blockEnd1 =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::sge, blockEnd0, tripCount);
- Value blockEnd2 = b.create<SelectOp>(blockEnd1, tripCount, blockEnd0);
- Value blockLastIndex = b.create<arith::SubIOp>(blockEnd2, c1);
+ Value blockEnd1 = b.create<arith::MinSIOp>(blockEnd0, tripCount);
+ Value blockLastIndex = b.create<arith::SubIOp>(blockEnd1, c1);
// Convert one-dimensional indices to multi-dimensional coordinates.
auto blockFirstCoord = delinearize(b, blockFirstIndex, tripCounts);
@@ -696,17 +693,9 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
// blockSize = min(tripCount,
// max(ceil_div(tripCount, maxComputeBlocks),
// ceil_div(minTaskSize, bodySize)))
- Value bs0 = b.create<arith::DivSIOp>(tripCount, maxComputeBlocks);
- Value bs1 =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::sge, bs0, minTaskSizeCst);
- Value bs2 = b.create<SelectOp>(bs1, bs0, minTaskSizeCst);
- Value bs3 =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, tripCount, bs2);
- Value blockSize0 = b.create<SelectOp>(bs3, tripCount, bs2);
- Value blockCount0 = b.create<arith::CeilDivSIOp>(tripCount, blockSize0);
-
- // Compute balanced block size for the estimated block count.
- Value blockSize = b.create<arith::CeilDivSIOp>(tripCount, blockCount0);
+ Value bs0 = b.create<arith::CeilDivSIOp>(tripCount, maxComputeBlocks);
+ Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSizeCst);
+ Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1);
Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);
// Create a parallel compute function that takes a block id and computes the
diff --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
index 278bbf284d112..b9b08bbffecfd 100644
--- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
+++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
@@ -2,6 +2,7 @@
// RUN: -async-to-async-runtime \
// RUN: -async-runtime-ref-counting \
// RUN: -async-runtime-ref-counting-opt \
+// RUN: -arith-expand \
// RUN: -convert-async-to-llvm \
// RUN: -convert-scf-to-std \
// RUN: -convert-memref-to-llvm \
@@ -16,6 +17,7 @@
// RUN: mlir-opt %s -async-parallel-for \
// RUN: -async-to-async-runtime \
// RUN: -async-runtime-policy-based-ref-counting \
+// RUN: -arith-expand \
// RUN: -convert-async-to-llvm \
// RUN: -convert-scf-to-std \
// RUN: -convert-memref-to-llvm \
@@ -33,6 +35,7 @@
// RUN: -async-to-async-runtime \
// RUN: -async-runtime-ref-counting \
// RUN: -async-runtime-ref-counting-opt \
+// RUN: -arith-expand \
// RUN: -convert-async-to-llvm \
// RUN: -convert-scf-to-std \
// RUN: -convert-memref-to-llvm \
More information about the Mlir-commits mailing list