[Mlir-commits] [mlir] 7bd87a0 - Promote readability by factoring out creation of min/max operation. Remove unnecessary divisions.

Eugene Zhulenev llvmlistbot at llvm.org
Wed Nov 24 16:17:31 PST 2021


Author: bakhtiyar
Date: 2021-11-24T16:17:23-08:00
New Revision: 7bd87a03fdf1559569db1820abb21b6a479b0934

URL: https://github.com/llvm/llvm-project/commit/7bd87a03fdf1559569db1820abb21b6a479b0934
DIFF: https://github.com/llvm/llvm-project/commit/7bd87a03fdf1559569db1820abb21b6a479b0934.diff

LOG: Promote readability by factoring out creation of min/max operation. Remove unnecessary divisions.

Reviewed By: ezhulenev

Differential Revision: https://reviews.llvm.org/D110680

Added: 
    

Modified: 
    mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
    mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 77b8717145c6f..c9514a9da0851 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -252,12 +252,9 @@ createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
   Value blockFirstIndex = b.create<arith::MulIOp>(blockIndex, blockSize);
 
   // The last one-dimensional index in the block defined by the `blockIndex`:
-  //   blockLastIndex = max(blockFirstIndex + blockSize, tripCount) - 1
   Value blockEnd0 = b.create<arith::AddIOp>(blockFirstIndex, blockSize);
-  Value blockEnd1 =
-      b.create<arith::CmpIOp>(arith::CmpIPredicate::sge, blockEnd0, tripCount);
-  Value blockEnd2 = b.create<SelectOp>(blockEnd1, tripCount, blockEnd0);
-  Value blockLastIndex = b.create<arith::SubIOp>(blockEnd2, c1);
+  Value blockEnd1 = b.create<arith::MinSIOp>(blockEnd0, tripCount);
+  Value blockLastIndex = b.create<arith::SubIOp>(blockEnd1, c1);
 
   // Convert one-dimensional indices to multi-dimensional coordinates.
   auto blockFirstCoord = delinearize(b, blockFirstIndex, tripCounts);
@@ -696,17 +693,9 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
     //   blockSize = min(tripCount,
     //                   max(ceil_div(tripCount, maxComputeBlocks),
     //                       ceil_div(minTaskSize, bodySize)))
-    Value bs0 = b.create<arith::DivSIOp>(tripCount, maxComputeBlocks);
-    Value bs1 =
-        b.create<arith::CmpIOp>(arith::CmpIPredicate::sge, bs0, minTaskSizeCst);
-    Value bs2 = b.create<SelectOp>(bs1, bs0, minTaskSizeCst);
-    Value bs3 =
-        b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, tripCount, bs2);
-    Value blockSize0 = b.create<SelectOp>(bs3, tripCount, bs2);
-    Value blockCount0 = b.create<arith::CeilDivSIOp>(tripCount, blockSize0);
-
-    // Compute balanced block size for the estimated block count.
-    Value blockSize = b.create<arith::CeilDivSIOp>(tripCount, blockCount0);
+    Value bs0 = b.create<arith::CeilDivSIOp>(tripCount, maxComputeBlocks);
+    Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSizeCst);
+    Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1);
     Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);
 
     // Create a parallel compute function that takes a block id and computes the

diff  --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
index 278bbf284d112..b9b08bbffecfd 100644
--- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
+++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
@@ -2,6 +2,7 @@
 // RUN:               -async-to-async-runtime                                  \
 // RUN:               -async-runtime-ref-counting                              \
 // RUN:               -async-runtime-ref-counting-opt                          \
+// RUN:               -arith-expand                                            \
 // RUN:               -convert-async-to-llvm                                   \
 // RUN:               -convert-scf-to-std                                      \
 // RUN:               -convert-memref-to-llvm                                  \
@@ -16,6 +17,7 @@
 // RUN:   mlir-opt %s -async-parallel-for                                      \
 // RUN:               -async-to-async-runtime                                  \
 // RUN:               -async-runtime-policy-based-ref-counting                 \
+// RUN:               -arith-expand                                            \
 // RUN:               -convert-async-to-llvm                                   \
 // RUN:               -convert-scf-to-std                                      \
 // RUN:               -convert-memref-to-llvm                                  \
@@ -33,6 +35,7 @@
 // RUN:               -async-to-async-runtime                                  \
 // RUN:               -async-runtime-ref-counting                              \
 // RUN:               -async-runtime-ref-counting-opt                          \
+// RUN:               -arith-expand                                            \
 // RUN:               -convert-async-to-llvm                                   \
 // RUN:               -convert-scf-to-std                                      \
 // RUN:               -convert-memref-to-llvm                                  \


        


More information about the Mlir-commits mailing list