[Mlir-commits] [mlir] 1e60678 - [MLIR] Fix parallel loop tiling.

Stephan Herhut llvmlistbot at llvm.org
Wed Jun 17 14:30:38 PDT 2020


Author: Stephan Herhut
Date: 2020-06-17T23:30:13+02:00
New Revision: 1e60678c1f68a9ba109a669afa471834692ce979

URL: https://github.com/llvm/llvm-project/commit/1e60678c1f68a9ba109a669afa471834692ce979
DIFF: https://github.com/llvm/llvm-project/commit/1e60678c1f68a9ba109a669afa471834692ce979.diff

LOG: [MLIR] Fix parallel loop tiling.

Summary:
Parallel loop tiling did not properly compute the updated loop
indices when tiling, which led to wrong results.

Differential Revision: https://reviews.llvm.org/D82013

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
    mlir/test/Dialect/SCF/parallel-loop-tiling.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
index 8e84566659f8..40469138ea01 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
@@ -30,9 +30,13 @@ using namespace mlir::scf;
 ///   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
 ///                                             step (%arg4*tileSize[0],
 ///                                                   %arg5*tileSize[1])
-///     scf.parallel (%j0, %j1) = (0, 0) to (min(tileSize[0], %arg2-%j0)
-///                                           min(tileSize[1], %arg3-%j1))
+///     scf.parallel (%j0, %j1) = (0, 0) to (min(tileSize[0], %arg2-%i0)
+///                                           min(tileSize[1], %arg3-%i1))
 ///                                        step (%arg4, %arg5)
+///
+/// where the uses of %i0 and %i1 in the loop body are replaced by
+/// %i0 + %j0 and %i1 + %j1.
+///
 /// The old loop is replaced with the new one.
 void mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) {
   OpBuilder b(op);
@@ -85,6 +89,18 @@ void mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) {
 
   // Steal the body of the old parallel loop and erase it.
   innerLoop.region().takeBody(op.region());
+
+  // Insert computation for new index vectors and replace uses.
+  b.setInsertionPointToStart(innerLoop.getBody());
+  for (auto ivs :
+       llvm::zip(innerLoop.getInductionVars(), outerLoop.getInductionVars())) {
+    Value inner_index = std::get<0>(ivs);
+    AddIOp newIndex =
+        b.create<AddIOp>(op.getLoc(), std::get<0>(ivs), std::get<1>(ivs));
+    inner_index.replaceAllUsesExcept(
+        newIndex, SmallPtrSet<Operation *, 1>{newIndex.getOperation()});
+  }
+
   op.erase();
 }
 

diff  --git a/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir b/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir
index 14912436f96b..f12416266ed9 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir
@@ -25,10 +25,12 @@ func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
 // CHECK:             [[VAL_17:%.*]] = affine.min #map0([[VAL_11]], [[VAL_2]], [[VAL_15]])
 // CHECK:             [[VAL_18:%.*]] = affine.min #map0([[VAL_12]], [[VAL_3]], [[VAL_16]])
 // CHECK:             scf.parallel ([[VAL_19:%.*]], [[VAL_20:%.*]]) = ([[VAL_10]], [[VAL_10]]) to ([[VAL_17]], [[VAL_18]]) step ([[VAL_4]], [[VAL_5]]) {
-// CHECK:               [[VAL_21:%.*]] = load [[VAL_7]]{{\[}}[[VAL_19]], [[VAL_20]]] : memref<?x?xf32>
-// CHECK:               [[VAL_22:%.*]] = load [[VAL_8]]{{\[}}[[VAL_19]], [[VAL_20]]] : memref<?x?xf32>
-// CHECK:               [[VAL_23:%.*]] = addf [[VAL_21]], [[VAL_22]] : f32
-// CHECK:               store [[VAL_23]], [[VAL_9]]{{\[}}[[VAL_19]], [[VAL_20]]] : memref<?x?xf32>
+// CHECK:               [[VAL_21:%.*]] = addi [[VAL_19]], [[VAL_15]] : index
+// CHECK:               [[VAL_22:%.*]] = addi [[VAL_20]], [[VAL_16]] : index
+// CHECK:               [[VAL_23:%.*]] = load [[VAL_7]]{{\[}}[[VAL_21]], [[VAL_22]]] : memref<?x?xf32>
+// CHECK:               [[VAL_24:%.*]] = load [[VAL_8]]{{\[}}[[VAL_21]], [[VAL_22]]] : memref<?x?xf32>
+// CHECK:               [[VAL_25:%.*]] = addf [[VAL_23]], [[VAL_24]] : f32
+// CHECK:               store [[VAL_25]], [[VAL_9]]{{\[}}[[VAL_21]], [[VAL_22]]] : memref<?x?xf32>
 // CHECK:             }
 // CHECK:           }
 // CHECK:           return


        


More information about the Mlir-commits mailing list