[Mlir-commits] [mlir] 2d45e33 - [MLIR][DISC] Revise ParallelLoopTilingPass with inbound_check mode

Stephan Herhut llvmlistbot at llvm.org
Mon Aug 16 05:03:32 PDT 2021


Author: tashuang.zk
Date: 2021-08-16T14:02:53+02:00
New Revision: 2d45e332ba321a22996a1584af26d568b375b674

URL: https://github.com/llvm/llvm-project/commit/2d45e332ba321a22996a1584af26d568b375b674
DIFF: https://github.com/llvm/llvm-project/commit/2d45e332ba321a22996a1584af26d568b375b674.diff

LOG: [MLIR][DISC] Revise ParallelLoopTilingPass with inbound_check mode

Expand ParallelLoopTilingPass with an inbound_check mode.

In default mode, the upper bound of the inner loop is from the min op; in
inbound_check mode, the upper bound of the inner loop is the step of the outer
loop and an additional inbound check will be emitted inside of the inner loop.

This was a 'FIXME' in the original code. A typical use case is GPU backends,
where the outer loop and inner loop can be mapped to blocks and threads separately.

Differential Revision: https://reviews.llvm.org/D105455

Added: 
    mlir/test/Dialect/SCF/parallel-loop-tiling-inbound-check.mlir

Modified: 
    mlir/include/mlir/Dialect/SCF/Passes.h
    mlir/include/mlir/Dialect/SCF/Passes.td
    mlir/include/mlir/Dialect/SCF/Transforms.h
    mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/Passes.h b/mlir/include/mlir/Dialect/SCF/Passes.h
index f8ed2c429b47f..df6a27232a0d3 100644
--- a/mlir/include/mlir/Dialect/SCF/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Passes.h
@@ -36,8 +36,13 @@ std::unique_ptr<Pass> createParallelLoopFusionPass();
 std::unique_ptr<Pass> createParallelLoopSpecializationPass();
 
 /// Creates a pass which tiles innermost parallel loops.
+/// If noMinMaxBounds, the upper bound of the inner loop will
+/// be a same value among different outter loop iterations, and
+/// an additional inbound check will be emitted inside the internal
+/// loops.
 std::unique_ptr<Pass>
-createParallelLoopTilingPass(llvm::ArrayRef<int64_t> tileSize = {});
+createParallelLoopTilingPass(llvm::ArrayRef<int64_t> tileSize = {},
+                             bool noMinMaxBounds = false);
 
 /// Creates a pass which folds arith ops on induction variable into
 /// loop range.

diff --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td
index 5e2a3a81bc0f0..44a7617ae4d97 100644
--- a/mlir/include/mlir/Dialect/SCF/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Passes.td
@@ -47,7 +47,11 @@ def SCFParallelLoopTiling : FunctionPass<"parallel-loop-tiling"> {
   let options = [
     ListOption<"tileSizes", "parallel-loop-tile-sizes", "int64_t",
                "Factors to tile parallel loops by",
-               "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
+               "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
+    Option<"noMinMaxBounds", "no-min-max-bounds", "bool",
+           /*default=*/"false",
+           "Perform tiling with fixed upper bound with inbound check "
+           "inside the internal loops">
   ];
   let dependentDialects = ["AffineDialect"];
 }

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms.h
index 5cb816c808fa1..603fd3f0e9c30 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms.h
@@ -87,7 +87,8 @@ LogicalResult peelForLoop(RewriterBase &b, ForOp forOp, scf::IfOp &ifOp);
 /// The function returns the resulting ParallelOps, i.e. {outer_loop_op,
 /// inner_loop_op}.
 std::pair<ParallelOp, ParallelOp>
-tileParallelLoop(ParallelOp op, llvm::ArrayRef<int64_t> tileSizes);
+tileParallelLoop(ParallelOp op, llvm::ArrayRef<int64_t> tileSizes,
+                 bool noMinMaxBounds);
 
 /// Populates patterns for SCF structural type conversions and sets up the
 /// provided ConversionTarget with the appropriate legality configuration for

diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
index 8282c0771f302..af8dcaf4b5c38 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
@@ -33,12 +33,25 @@ using namespace mlir::scf;
 ///                                          min(%arg5*tileSize[1], %arg3-%i1))
 ///                                      step (%arg4, %arg5)
 ///
+/// or, when no-min-max-bounds is true, into
+///   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+///                                            step (%arg4*tileSize[0],
+///                                                  %arg5*tileSize[1])
+///     scf.parallel (%j0, %j1) = (0, 0) to (%arg4*tileSize[0],
+///                                          %arg5*tileSize[1])
+///                                      step (%arg4, %arg5)
+///        %inbound = (%j0 * %arg4 + %i0 < %arg2) &&
+///                   (%j1 * %arg5 + %i1 < %arg3)
+///        scf.if (%inbound)
+///          ....
+///
 /// where the uses of %i0 and %i1 in the loop body are replaced by
 /// %i0 + j0 and %i1 + %j1.
 //
 /// The old loop is replaced with the new one.
 std::pair<ParallelOp, ParallelOp>
-mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) {
+mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes,
+                            bool noMinMaxBounds) {
   OpBuilder b(op);
   auto zero = b.create<ConstantIndexOp>(op.getLoc(), 0);
   SmallVector<Value, 2> tileSizeConstants;
@@ -64,8 +77,6 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) {
   b.setInsertionPointToStart(outerLoop.getBody());
 
   // Compute min(size, dim - offset) to avoid out-of-bounds accesses.
-  // FIXME: Instead of using min, we want to replicate the tail. This would give
-  // the inner loop constant bounds for easy vectorization.
   auto minMap = AffineMap::get(
       /*dimCount=*/3, /*symbolCount=*/0,
       {getAffineDimExpr(/*position=*/0, b.getContext()),
@@ -76,6 +87,7 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) {
   // Create the inner loop with adjusted bounds.
   SmallVector<Value, 2> newBounds;
   newBounds.reserve(op.upperBound().size());
+  bool needInboundCheck = false;
   for (auto dim : llvm::zip(outerLoop.lowerBound(), outerLoop.upperBound(),
                             outerLoop.step(), outerLoop.getInductionVars(),
                             op.step(), tileSizeConstants)) {
@@ -101,6 +113,14 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) {
         continue;
       }
     }
+
+    // For InboundCheck mode, just use the variable outer step
+    if (noMinMaxBounds) {
+      newBounds.push_back(newStep);
+      needInboundCheck = true;
+      continue;
+    }
+
     // Otherwise, we dynamically compute the bound for
     // each iteration of the outer loop.
     newBounds.push_back(
@@ -111,17 +131,51 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) {
       op.getLoc(), SmallVector<Value, 2>(newBounds.size(), zero), newBounds,
       op.step());
 
-  // Steal the body of the old parallel loop and erase it.
-  innerLoop.region().takeBody(op.region());
-
-  // Insert computation for new index vectors and replace uses.
-  b.setInsertionPointToStart(innerLoop.getBody());
-  for (auto ivs :
-       llvm::zip(innerLoop.getInductionVars(), outerLoop.getInductionVars())) {
-    Value inner_index = std::get<0>(ivs);
-    AddIOp newIndex =
-        b.create<AddIOp>(op.getLoc(), std::get<0>(ivs), std::get<1>(ivs));
-    inner_index.replaceAllUsesExcept(newIndex, newIndex);
+  if (noMinMaxBounds && needInboundCheck) {
+    b.setInsertionPointToStart(innerLoop.getBody());
+    // Insert in-bound check
+    Value inbound =
+        b.create<ConstantOp>(op.getLoc(), b.getIntegerType(1),
+                             b.getIntegerAttr(b.getIntegerType(1), 1));
+    for (auto dim :
+         llvm::zip(outerLoop.upperBound(), outerLoop.getInductionVars(),
+                   innerLoop.getInductionVars(), innerLoop.step())) {
+      Value outerUpperBound, outerIV, innerIV, innerStep;
+      std::tie(outerUpperBound, outerIV, innerIV, innerStep) = dim;
+      // %in_bound = %in_bound &&
+      //             (%inner_iv * %inner_step + %outer_iv < %outer_upper_bound)
+      Value index = b.create<AddIOp>(
+          op.getLoc(), b.create<MulIOp>(op.getLoc(), innerIV, innerStep),
+          outerIV);
+      Value dimInbound = b.create<CmpIOp>(op.getLoc(), CmpIPredicate::ult,
+                                          index, outerUpperBound);
+      inbound = b.create<AndOp>(op.getLoc(), inbound, dimInbound);
+    }
+    auto ifInbound = b.create<IfOp>(op.getLoc(),
+                                    /*resultTypes*/ ArrayRef<Type>{}, inbound,
+                                    /*hasElseRegion*/ false);
+    ifInbound.thenRegion().takeBody(op.region());
+    Block &thenBlock = ifInbound.thenRegion().front();
+    b.setInsertionPointToStart(innerLoop.getBody());
+    for (auto ivs : llvm::enumerate(llvm::zip(innerLoop.getInductionVars(),
+                                              outerLoop.getInductionVars()))) {
+      AddIOp newIndex = b.create<AddIOp>(op.getLoc(), std::get<0>(ivs.value()),
+                                         std::get<1>(ivs.value()));
+      thenBlock.getArgument(ivs.index())
+          .replaceAllUsesExcept(newIndex, newIndex);
+    }
+    thenBlock.eraseArguments(llvm::to_vector<4>(
+        llvm::seq((unsigned)0, thenBlock.getNumArguments())));
+  } else {
+    innerLoop.region().takeBody(op.region());
+    b.setInsertionPointToStart(innerLoop.getBody());
+    for (auto ivs : llvm::zip(innerLoop.getInductionVars(),
+                              outerLoop.getInductionVars())) {
+      Value innerIndex = std::get<0>(ivs);
+      AddIOp newIndex =
+          b.create<AddIOp>(op.getLoc(), std::get<0>(ivs), std::get<1>(ivs));
+      innerIndex.replaceAllUsesExcept(newIndex, newIndex);
+    }
   }
 
   op.erase();
@@ -132,8 +186,10 @@ namespace {
 struct ParallelLoopTiling
     : public SCFParallelLoopTilingBase<ParallelLoopTiling> {
   ParallelLoopTiling() = default;
-  explicit ParallelLoopTiling(ArrayRef<int64_t> tileSizes) {
+  explicit ParallelLoopTiling(ArrayRef<int64_t> tileSizes,
+                              bool noMinMaxBounds = false) {
     this->tileSizes = tileSizes;
+    this->noMinMaxBounds = noMinMaxBounds;
   }
 
   void runOnFunction() override {
@@ -142,13 +198,14 @@ struct ParallelLoopTiling
     for (ParallelOp ploop : innermostPloops) {
       // FIXME: Add reduction support.
       if (ploop.getNumReductions() == 0)
-        tileParallelLoop(ploop, tileSizes);
+        tileParallelLoop(ploop, tileSizes, noMinMaxBounds);
     }
   }
 };
 } // namespace
 
 std::unique_ptr<Pass>
-mlir::createParallelLoopTilingPass(ArrayRef<int64_t> tileSizes) {
-  return std::make_unique<ParallelLoopTiling>(tileSizes);
+mlir::createParallelLoopTilingPass(ArrayRef<int64_t> tileSizes,
+                                   bool noMinMaxBounds) {
+  return std::make_unique<ParallelLoopTiling>(tileSizes, noMinMaxBounds);
 }

diff --git a/mlir/test/Dialect/SCF/parallel-loop-tiling-inbound-check.mlir b/mlir/test/Dialect/SCF/parallel-loop-tiling-inbound-check.mlir
new file mode 100644
index 0000000000000..8f395c3b2828a
--- /dev/null
+++ b/mlir/test/Dialect/SCF/parallel-loop-tiling-inbound-check.mlir
@@ -0,0 +1,149 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.func(parallel-loop-tiling{parallel-loop-tile-sizes=1,4 no-min-max-bounds=true})' -split-input-file | FileCheck %s
+
+func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
+                    %arg3 : index, %arg4 : index, %arg5 : index,
+		    %A: memref<?x?xf32>, %B: memref<?x?xf32>,
+                    %C: memref<?x?xf32>, %result: memref<?x?xf32>) {
+  scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
+    %B_elem = memref.load %B[%i0, %i1] : memref<?x?xf32>
+    %C_elem = memref.load %C[%i0, %i1] : memref<?x?xf32>
+    %sum_elem = addf %B_elem, %C_elem : f32
+    memref.store %sum_elem, %result[%i0, %i1] : memref<?x?xf32>
+  }
+  return
+}
+
+// CHECK-LABEL:   func @parallel_loop(
+// CHECK-SAME:                        [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index, [[ARG5:%.*]]: index, [[ARG6:%.*]]: index, [[ARG7:%.*]]: memref<?x?xf32>, [[ARG8:%.*]]: memref<?x?xf32>, [[ARG9:%.*]]: memref<?x?xf32>, [[ARG10:%.*]]: memref<?x?xf32>) {
+// CHECK:           [[C0:%.*]] = constant 0 : index
+// CHECK:           [[C1:%.*]] = constant 1 : index
+// CHECK:           [[C4:%.*]] = constant 4 : index
+// CHECK:           [[V1:%.*]] = muli [[ARG5]], [[C1]] : index
+// CHECK:           [[V2:%.*]] = muli [[ARG6]], [[C4]] : index
+// CHECK:           scf.parallel ([[V3:%.*]], [[V4:%.*]]) = ([[ARG1]], [[ARG2]]) to ([[ARG3]], [[ARG4]]) step ([[V1]], [[V2]]) {
+// CHECK:             scf.parallel ([[V7:%.*]], [[V8:%.*]]) = ([[C0]], [[C0]]) to ([[V1]], [[V2]]) step ([[ARG5]], [[ARG6]]) {
+// CHECK:               [[V9:%.*]] = addi [[V7]], [[V3]] : index
+// CHECK:               [[V10:%.*]] = addi [[V8]], [[V4]] : index
+// CHECK:               %true = constant true
+// CHECK:               [[V11:%.*]] = muli [[V7]], [[ARG5]] : index
+// CHECK:               [[V12:%.*]] = addi [[V11]], [[V3]] : index
+// CHECK:               [[V13:%.*]] = cmpi ult, [[V12]], [[ARG3]] : index
+// CHECK:               [[V14:%.*]] = and %true, [[V13]] : i1
+// CHECK:               [[V15:%.*]] = muli [[V8]], [[ARG6]] : index
+// CHECK:               [[V16:%.*]] = addi [[V15]], [[V4]] : index
+// CHECK:               [[V17:%.*]] = cmpi ult, [[V16]], [[ARG4]] : index
+// CHECK:               [[V18:%.*]] = and [[V14]], [[V17]] : i1
+// CHECK:               scf.if [[V18]] {
+// CHECK:                 [[V19:%.*]] = memref.load [[ARG8]]{{\[}}[[V9]], [[V10]]] : memref<?x?xf32>
+// CHECK:                 [[V20:%.*]] = memref.load [[ARG9]]{{\[}}[[V9]], [[V10]]] : memref<?x?xf32>
+// CHECK:                 [[V21:%.*]] = addf [[V19]], [[V20]] : f32
+// CHECK:                 memref.store [[V21]], [[ARG10]]{{\[}}[[V9]], [[V10]]] : memref<?x?xf32>
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:           return
+
+// -----
+
+func @static_loop_with_step() {
+  %c0 = constant 0 : index
+  %c3 = constant 3 : index
+  %c22 = constant 22 : index
+  %c24 = constant 24 : index
+  scf.parallel (%i0, %i1) = (%c0, %c0) to (%c22, %c24) step (%c3, %c3) {
+  }
+  return
+}
+
+// CHECK-LABEL:   func @static_loop_with_step() {
+// CHECK:           [[C0:%.*]] = constant 0 : index
+// CHECK:           [[C3:%.*]] = constant 3 : index
+// CHECK:           [[C22:%.*]] = constant 22 : index
+// CHECK:           [[C24:%.*]] = constant 24 : index
+// CHECK:           [[C0_1:%.*]] = constant 0 : index
+// CHECK:           [[C1:%.*]] = constant 1 : index
+// CHECK:           [[C4:%.*]] = constant 4 : index
+// CHECK:           [[V1:%.*]] = muli [[C3]], [[C1]] : index
+// CHECK:           [[V2:%.*]] = muli [[C3]], [[C4]] : index
+// CHECK:           scf.parallel ([[V3:%.*]], [[V4:%.*]]) = ([[C0]], [[C0]]) to ([[C22]], [[C24]]) step ([[V1]], [[V2]]) {
+// CHECK:             scf.parallel ([[V5:%.*]], [[V6:%.*]]) = ([[C0_1]], [[C0_1]]) to ([[V1]], [[V2]]) step ([[C3]], [[C3]]) {
+// CHECK-NOT:           scf.if
+// CHECK:               = addi [[V5]], [[V3]] : index
+// CHECK:               = addi [[V6]], [[V4]] : index
+// CHECK:             }
+// CHECK:           }
+// CHECK:           return
+
+// -----
+
+func @tile_nested_innermost() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    scf.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    }
+  }
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+  }
+  return
+}
+
+// CHECK-LABEL:   func @tile_nested_innermost() {
+// CHECK:           [[C2:%.*]] = constant 2 : index
+// CHECK:           [[C0:%.*]] = constant 0 : index
+// CHECK:           [[C1:%.*]] = constant 1 : index
+// CHECK:           scf.parallel ([[V1:%.*]], [[V2:%.*]]) = ([[C0]], [[C0]]) to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK:             [[C0_1:%.*]] = constant 0 : index
+// CHECK:             [[C1_1:%.*]] = constant 1 : index
+// CHECK:             [[C4:%.*]] = constant 4 : index
+// CHECK:             [[V3:%.*]] = muli [[C1]], [[C1_1]] : index
+// CHECK:             [[V4:%.*]] = muli [[C1]], [[C4]] : index
+// CHECK:             scf.parallel ([[V5:%.*]], [[V6:%.*]]) = ([[C0]], [[C0]]) to ([[C2]], [[C2]]) step ([[V3]], [[V4]]) {
+// CHECK:               scf.parallel ([[V8:%.*]], [[V9:%.*]]) = ([[C0_1]], [[C0_1]]) to ([[V3]], [[V4]]) step ([[C1]], [[C1]]) {
+// CHECK:                 = addi [[V8]], [[V5]] : index
+// CHECK:                 = addi [[V9]], [[V6]] : index
+// CHECK:                 scf.if
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:           [[C0_2:%.*]] = constant 0 : index
+// CHECK:           [[C1_2:%.*]] = constant 1 : index
+// CHECK:           [[C4_1:%.*]] = constant 4 : index
+// CHECK:           [[V10:%.*]] = muli [[C1]], [[C1_2]] : index
+// CHECK:           [[V11:%.*]] = muli [[C1]], [[C4_1]] : index
+// CHECK:           scf.parallel ([[V12:%.*]], [[V13:%.*]]) = ([[C0]], [[C0]]) to ([[C2]], [[C2]]) step ([[V10]], [[V11]]) {
+// CHECK:             scf.parallel ([[V15:%.*]], [[V16:%.*]]) = ([[C0_2]], [[C0_2]]) to ([[V10]], [[V11]]) step ([[C1]], [[C1]]) {
+// CHECK:               = addi [[V15]], [[V12]] : index
+// CHECK:               = addi [[V16]], [[V13]] : index
+// CHECK:               scf.if
+// CHECK:             }
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+func @tile_nested_in_non_ploop() {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  scf.for %i = %c0 to %c2 step %c1 {
+    scf.for %j = %c0 to %c2 step %c1 {
+      scf.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+      }
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @tile_nested_in_non_ploop
+// CHECK:         scf.for
+// CHECK:           scf.for
+// CHECK:             scf.parallel
+// CHECK:               scf.parallel
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:         }
+// CHECK:       }


        


More information about the Mlir-commits mailing list