[Mlir-commits] [mlir] 3a41ff4 - [mlir][SCF] Peel scf.for loops for even step division
Matthias Springer
llvmlistbot at llvm.org
Mon Aug 2 18:34:22 PDT 2021
Author: Matthias Springer
Date: 2021-08-03T10:21:38+09:00
New Revision: 3a41ff4883fe8b9e34a4f30aa9eecaf2ecb2ef44
URL: https://github.com/llvm/llvm-project/commit/3a41ff4883fe8b9e34a4f30aa9eecaf2ecb2ef44
DIFF: https://github.com/llvm/llvm-project/commit/3a41ff4883fe8b9e34a4f30aa9eecaf2ecb2ef44.diff
LOG: [mlir][SCF] Peel scf.for loops for even step division
Add ForLoopPeeling pass (`-for-loop-peeling`), which peels scf.for loops into a "main loop", in which `step` divides the iteration space evenly, and an scf.if that handles the last (partial) iteration.
This transformation is useful for vectorization and loop tiling. E.g., when vectorizing loads/stores, programs will spend most of their time in the main loop, in which only unmasked loads/stores are used. Only in the last iteration (scf.if) are the slower masked loads/stores used.
Subsequent commits will apply this transformation in the SparseDialect and in Linalg's loop tiling.
Differential Revision: https://reviews.llvm.org/D105804
Added:
mlir/test/Dialect/SCF/for-loop-peeling.mlir
Modified:
mlir/include/mlir/Dialect/SCF/Passes.h
mlir/include/mlir/Dialect/SCF/Passes.td
mlir/include/mlir/Dialect/SCF/Transforms.h
mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/Passes.h b/mlir/include/mlir/Dialect/SCF/Passes.h
index 2d1f8b5aff051..f8ed2c429b47f 100644
--- a/mlir/include/mlir/Dialect/SCF/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Passes.h
@@ -24,6 +24,10 @@ std::unique_ptr<Pass> createSCFBufferizePass();
/// vectorization.
std::unique_ptr<Pass> createForLoopSpecializationPass();
+/// Creates a pass that peels for loops at their upper bounds for
+/// better vectorization: each `scf.for` is split into a main loop whose step
+/// divides the iteration space evenly and an `scf.if` that executes the last
+/// (partial) iteration, if any.
+std::unique_ptr<Pass> createForLoopPeelingPass();
+
/// Creates a loop fusion pass which fuses parallel loops.
std::unique_ptr<Pass> createParallelLoopFusionPass();
diff --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td
index 172fb63206809..5e2a3a81bc0f0 100644
--- a/mlir/include/mlir/Dialect/SCF/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Passes.td
@@ -17,6 +17,13 @@ def SCFBufferize : FunctionPass<"scf-bufferize"> {
let dependentDialects = ["memref::MemRefDialect"];
}
+def SCFForLoopPeeling
+ : FunctionPass<"for-loop-peeling"> {
+ let summary = "Peel `for` loops at their upper bounds.";
+ let constructor = "mlir::createForLoopPeelingPass()";
+ let dependentDialects = ["AffineDialect"];
+}
+
def SCFForLoopSpecialization
: FunctionPass<"for-loop-specialization"> {
let summary = "Specialize `for` loops for vectorization";
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms.h
index e1b0881d4af04..5cb816c808fa1 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms.h
@@ -18,8 +18,10 @@
namespace mlir {
class ConversionTarget;
+struct LogicalResult;
class MLIRContext;
class Region;
+class RewriterBase;
class TypeConverter;
class RewritePatternSet;
using OwningRewritePatternList = RewritePatternSet;
@@ -27,6 +29,8 @@ class Operation;
namespace scf {
+class IfOp;
+class ForOp;
class ParallelOp;
class ForOp;
@@ -35,6 +39,38 @@ class ForOp;
/// analysis.
void naivelyFuseParallelOps(Region ®ion);
+/// Rewrite a for loop with bounds/step that potentially do not divide evenly
+/// into a for loop where the step divides the iteration space evenly, followed
+/// by an scf.if for the last (partial) iteration (if any). This transformation
+/// is called "loop peeling".
+///
+/// Other patterns can simplify/canonicalize operations in the body of the loop
+/// and the scf.if. This is beneficial for a wide range of transformations such
+/// as vectorization or loop tiling.
+///
+/// E.g., assuming a lower bound of 0 (for illustration purposes):
+/// ```
+/// scf.for %iv = %c0 to %ub step %c4 {
+/// (loop body)
+/// }
+/// ```
+/// is rewritten into the following pseudo IR:
+/// ```
+/// %newUb = %ub - (%ub mod %c4)
+/// scf.for %iv = %c0 to %newUb step %c4 {
+/// (loop body)
+/// }
+/// scf.if %newUb < %ub {
+/// (loop body)
+/// }
+/// ```
+///
+/// This function rewrites the given scf.for loop in-place and creates a new
+/// scf.if operation (returned via `ifOp`) for the last iteration.
+/// Inside the scf.if, the induction variable is replaced by %newUb and the
+/// loop-carried values (iter_args) by the results of the main loop; all uses
+/// of the original loop results are redirected to the scf.if results.
+///
+/// Returns failure, without modifying the IR, if no peeling is necessary:
+/// when lower bound, upper bound and step are all constants and the step
+/// divides (ub - lb) evenly, or when the step is the constant 1.
+///
+/// TODO: Simplify affine.min ops inside the new loop/if statement.
+LogicalResult peelForLoop(RewriterBase &b, ForOp forOp, scf::IfOp &ifOp);
+
/// Tile a parallel loop of the form
/// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
/// step (%arg4, %arg5)
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index 54c663ca67bca..f086f15ac7444 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -15,9 +15,14 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/DenseMap.h"
using namespace mlir;
using scf::ForOp;
@@ -89,6 +94,89 @@ static void specializeForLoopForUnrolling(ForOp op) {
op.erase();
}
+/// Rewrite a for loop with bounds/step that potentially do not divide evenly
+/// into a for loop where the step divides the iteration space evenly, followed
+/// by an scf.if for the last (partial) iteration (if any).
+///
+/// Returns failure (leaving the IR untouched) if no peeling is necessary, i.e.
+/// if lb/ub/step are constants and step divides (ub - lb) evenly, or if the
+/// step is the constant 1. On success, `ifOp` is set to the new scf.if.
+LogicalResult mlir::scf::peelForLoop(RewriterBase &b, ForOp forOp,
+ scf::IfOp &ifOp) {
+ RewriterBase::InsertionGuard guard(b);
+ auto lbInt = getConstantIntValue(forOp.lowerBound());
+ auto ubInt = getConstantIntValue(forOp.upperBound());
+ auto stepInt = getConstantIntValue(forOp.step());
+
+ // No specialization necessary if step already divides upper bound evenly.
+ if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) % *stepInt == 0)
+ return failure();
+ // No specialization necessary if step size is 1.
+ if (stepInt == static_cast<int64_t>(1))
+ return failure();
+
+ // The split bound becomes an operand of `forOp`, so it must be defined
+ // before the loop. Do not rely on the caller's insertion point (direct API
+ // callers may have set it elsewhere, e.g. after the loop).
+ b.setInsertionPoint(forOp);
+ auto loc = forOp.getLoc();
+ AffineExpr dim0, dim1, dim2;
+ bindDims(b.getContext(), dim0, dim1, dim2);
+ // New upper bound: %ub - (%ub - %lb) mod %step
+ auto modMap = AffineMap::get(3, 0, {dim1 - ((dim1 - dim0) % dim2)});
+ Value splitBound = b.createOrFold<AffineApplyOp>(
+ loc, modMap,
+ ValueRange{forOp.lowerBound(), forOp.upperBound(), forOp.step()});
+
+ // Set new upper loop bound.
+ Value previousUb = forOp.upperBound();
+ b.updateRootInPlace(forOp,
+ [&]() { forOp.upperBoundMutable().assign(splitBound); });
+ b.setInsertionPointAfter(forOp);
+
+ // Do we need one more iteration?
+ Value hasMoreIter =
+ b.create<CmpIOp>(loc, CmpIPredicate::slt, splitBound, previousUb);
+
+ // Create IfOp for last iteration.
+ auto resultTypes = llvm::to_vector<4>(
+ llvm::map_range(forOp.initArgs(), [](Value v) { return v.getType(); }));
+ ifOp = b.create<scf::IfOp>(loc, resultTypes, hasMoreIter,
+ /*withElseRegion=*/!resultTypes.empty());
+ forOp.replaceAllUsesWith(ifOp->getResults());
+
+ // Build then case. The scf::IfOp builder already populates the "then" region
+ // with a default block (holding an scf.yield when there are no results).
+ // scf.if requires a single-block region, so that default block is replaced
+ // by the cloned loop body instead of being left behind as a second block.
+ Block *defaultThenBlock = &ifOp.thenRegion().front();
+ BlockAndValueMapping bvm;
+ // Induction variable -> split bound; iter_args -> results of the main loop.
+ bvm.map(forOp.region().getArgument(0), splitBound);
+ for (auto it : llvm::zip(forOp.region().getArguments().drop_front(),
+ forOp->getResults())) {
+ bvm.map(std::get<0>(it), std::get<1>(it));
+ }
+ b.cloneRegionBefore(forOp.region(), ifOp.thenRegion(),
+ ifOp.thenRegion().begin(), bvm);
+ b.eraseBlock(defaultThenBlock);
+ // Build else case: forward the main loop's results unchanged.
+ if (!resultTypes.empty())
+ ifOp.getElseBodyBuilder().create<scf::YieldOp>(loc, forOp->getResults());
+
+ return success();
+}
+
+/// Name of the unit attribute attached to scf.for loops that ForLoopPeelingPattern
+/// has already peeled. The ForLoopPeeling pass removes it again when done.
+static constexpr char kPeeledLoopLabel[] = "__peeled_loop__";
+
+namespace {
+/// Peels scf.for loops via scf::peelForLoop. After a successful peel, the main
+/// loop is tagged with `kPeeledLoopLabel`. Without that tag the greedy driver
+/// could match the same loop again: when the bounds are dynamic, the new upper
+/// bound is not a constant, so peelForLoop's early-exit checks would not
+/// prevent a second peel.
+struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
+ using OpRewritePattern<ForOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ForOp forOp,
+ PatternRewriter &rewriter) const override {
+ // Skip loops that were already peeled by this pattern.
+ if (forOp->hasAttr(kPeeledLoopLabel))
+ return failure();
+
+ scf::IfOp ifOp;
+ // peelForLoop returns failure without touching the IR if no peeling is
+ // needed, so returning failure here reports "no match" correctly.
+ if (failed(peelForLoop(rewriter, forOp, ifOp)))
+ return failure();
+ // Apply label, so that the same loop is not rewritten a second time.
+ rewriter.updateRootInPlace(forOp, [&]() {
+ forOp->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
+ });
+
+ return success();
+ }
+};
+} // namespace
+
namespace {
struct ParallelLoopSpecialization
: public SCFParallelLoopSpecializationBase<ParallelLoopSpecialization> {
@@ -104,6 +192,19 @@ struct ForLoopSpecialization
getFunction().walk([](ForOp op) { specializeForLoopForUnrolling(op); });
}
};
+
+/// Function pass that greedily applies ForLoopPeelingPattern to the scf.for
+/// loops of the function and then strips the `kPeeledLoopLabel` markers that
+/// the pattern attached.
+struct ForLoopPeeling : public SCFForLoopPeelingBase<ForLoopPeeling> {
+ void runOnFunction() override {
+ FuncOp funcOp = getFunction();
+ MLIRContext *ctx = funcOp.getContext();
+ RewritePatternSet patterns(ctx);
+ patterns.add<ForLoopPeelingPattern>(ctx);
+ // The LogicalResult returned by the greedy driver is discarded.
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+
+ // Drop the marker.
+ funcOp.walk([](ForOp op) { op->removeAttr(kPeeledLoopLabel); });
+ }
+};
} // namespace
std::unique_ptr<Pass> mlir::createParallelLoopSpecializationPass() {
@@ -113,3 +214,7 @@ std::unique_ptr<Pass> mlir::createParallelLoopSpecializationPass() {
std::unique_ptr<Pass> mlir::createForLoopSpecializationPass() {
return std::make_unique<ForLoopSpecialization>();
}
+
+/// Creates an instance of the ForLoopPeeling pass (`-for-loop-peeling`).
+std::unique_ptr<Pass> mlir::createForLoopPeelingPass() {
+ return std::make_unique<ForLoopPeeling>();
+}
diff --git a/mlir/test/Dialect/SCF/for-loop-peeling.mlir b/mlir/test/Dialect/SCF/for-loop-peeling.mlir
new file mode 100644
index 0000000000000..d13a43b223e61
--- /dev/null
+++ b/mlir/test/Dialect/SCF/for-loop-peeling.mlir
@@ -0,0 +1,155 @@
+// RUN: mlir-opt %s -for-loop-peeling -canonicalize -split-input-file | FileCheck %s
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0, s1, s2] -> (s1 - (s1 - s0) mod s2)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<()[s0, s1, s2] -> (s0, s2 - (s2 - (s2 - s1) mod s0))>
+// CHECK: func @fully_dynamic_bounds(
+// CHECK-SAME: %[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index
+// CHECK: %[[C0_I32:.*]] = constant 0 : i32
+// CHECK: %[[NEW_UB:.*]] = affine.apply #[[MAP0]]()[%[[LB]], %[[UB]], %[[STEP]]]
+// CHECK: %[[LOOP:.*]] = scf.for %[[IV:.*]] = %[[LB]] to %[[NEW_UB]]
+// CHECK-SAME: step %[[STEP]] iter_args(%[[ACC:.*]] = %[[C0_I32]]) -> (i32) {
+// CHECK: %[[MINOP:.*]] = affine.min #[[MAP1]](%[[IV]])[%[[STEP]], %[[UB]]]
+// CHECK: %[[CAST:.*]] = index_cast %[[MINOP]] : index to i32
+// CHECK: %[[ADD:.*]] = addi %[[ACC]], %[[CAST]] : i32
+// CHECK: scf.yield %[[ADD]]
+// CHECK: }
+// CHECK: %[[HAS_MORE:.*]] = cmpi slt, %[[NEW_UB]], %[[UB]]
+// CHECK: %[[RESULT:.*]] = scf.if %[[HAS_MORE]] -> (i32) {
+// CHECK: %[[REM:.*]] = affine.min #[[MAP2]]()[%[[STEP]], %[[LB]], %[[UB]]]
+// CHECK: %[[CAST2:.*]] = index_cast %[[REM]]
+// CHECK: %[[ADD2:.*]] = addi %[[LOOP]], %[[CAST2]]
+// CHECK: scf.yield %[[ADD2]]
+// CHECK: } else {
+// CHECK: scf.yield %[[LOOP]]
+// CHECK: }
+// CHECK: return %[[RESULT]]
+#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
+func @fully_dynamic_bounds(%lb : index, %ub: index, %step: index) -> i32 {
+ %c0 = constant 0 : i32
+ %r = scf.for %iv = %lb to %ub step %step iter_args(%arg = %c0) -> i32 {
+ %s = affine.min #map(%ub, %iv)[%step]
+ %casted = index_cast %s : index to i32
+ %0 = addi %arg, %casted : i32
+ scf.yield %0 : i32
+ }
+ return %r : i32
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0) -> (4, -d0 + 17)>
+// CHECK: func @fully_static_bounds(
+// CHECK-DAG: %[[C0_I32:.*]] = constant 0 : i32
+// CHECK-DAG: %[[C1_I32:.*]] = constant 1 : i32
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C4:.*]] = constant 4 : index
+// CHECK-DAG: %[[C16:.*]] = constant 16 : index
+// CHECK: %[[LOOP:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C16]]
+// CHECK-SAME: step %[[C4]] iter_args(%[[ACC:.*]] = %[[C0_I32]]) -> (i32) {
+// CHECK: %[[MINOP:.*]] = affine.min #[[MAP]](%[[IV]])
+// CHECK: %[[CAST:.*]] = index_cast %[[MINOP]] : index to i32
+// CHECK: %[[ADD:.*]] = addi %[[ACC]], %[[CAST]] : i32
+// CHECK: scf.yield %[[ADD]]
+// CHECK: }
+// CHECK: %[[RESULT:.*]] = addi %[[LOOP]], %[[C1_I32]] : i32
+// CHECK: return %[[RESULT]]
+#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
+func @fully_static_bounds() -> i32 {
+ %c0_i32 = constant 0 : i32
+ %lb = constant 0 : index
+ %step = constant 4 : index
+ %ub = constant 17 : index
+ %r = scf.for %iv = %lb to %ub step %step
+ iter_args(%arg = %c0_i32) -> i32 {
+ %s = affine.min #map(%ub, %iv)[%step]
+ %casted = index_cast %s : index to i32
+ %0 = addi %arg, %casted : i32
+ scf.yield %0 : i32
+ }
+ return %r : i32
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> ((s0 floordiv 4) * 4)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<()[s0] -> (4, s0 mod 4)>
+// CHECK: func @dynamic_upper_bound(
+// CHECK-SAME: %[[UB:.*]]: index
+// CHECK-DAG: %[[C0_I32:.*]] = constant 0 : i32
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C4:.*]] = constant 4 : index
+// CHECK: %[[NEW_UB:.*]] = affine.apply #[[MAP0]]()[%[[UB]]]
+// CHECK: %[[LOOP:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[NEW_UB]]
+// CHECK-SAME: step %[[C4]] iter_args(%[[ACC:.*]] = %[[C0_I32]]) -> (i32) {
+// CHECK: %[[MINOP:.*]] = affine.min #[[MAP1]](%[[IV]])[%[[UB]]]
+// CHECK: %[[CAST:.*]] = index_cast %[[MINOP]] : index to i32
+// CHECK: %[[ADD:.*]] = addi %[[ACC]], %[[CAST]] : i32
+// CHECK: scf.yield %[[ADD]]
+// CHECK: }
+// CHECK: %[[HAS_MORE:.*]] = cmpi slt, %[[NEW_UB]], %[[UB]]
+// CHECK: %[[RESULT:.*]] = scf.if %[[HAS_MORE]] -> (i32) {
+// CHECK: %[[REM:.*]] = affine.min #[[MAP2]]()[%[[UB]]]
+// CHECK: %[[CAST2:.*]] = index_cast %[[REM]]
+// CHECK: %[[ADD2:.*]] = addi %[[LOOP]], %[[CAST2]]
+// CHECK: scf.yield %[[ADD2]]
+// CHECK: } else {
+// CHECK: scf.yield %[[LOOP]]
+// CHECK: }
+// CHECK: return %[[RESULT]]
+#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
+func @dynamic_upper_bound(%ub : index) -> i32 {
+ %c0_i32 = constant 0 : i32
+ %lb = constant 0 : index
+ %step = constant 4 : index
+ %r = scf.for %iv = %lb to %ub step %step
+ iter_args(%arg = %c0_i32) -> i32 {
+ %s = affine.min #map(%ub, %iv)[%step]
+ %casted = index_cast %s : index to i32
+ %0 = addi %arg, %casted : i32
+ scf.yield %0 : i32
+ }
+ return %r : i32
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> ((s0 floordiv 4) * 4)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<()[s0] -> (4, s0 mod 4)>
+// CHECK: func @no_loop_results(
+// CHECK-SAME: %[[UB:.*]]: index, %[[MEMREF:.*]]: memref<i32>
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C4:.*]] = constant 4 : index
+// CHECK: %[[NEW_UB:.*]] = affine.apply #[[MAP0]]()[%[[UB]]]
+// CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[NEW_UB]] step %[[C4]] {
+// CHECK: %[[MINOP:.*]] = affine.min #[[MAP1]](%[[IV]])[%[[UB]]]
+// CHECK: %[[LOAD:.*]] = memref.load %[[MEMREF]][]
+// CHECK: %[[CAST:.*]] = index_cast %[[MINOP]] : index to i32
+// CHECK: %[[ADD:.*]] = addi %[[LOAD]], %[[CAST]] : i32
+// CHECK: memref.store %[[ADD]], %[[MEMREF]]
+// CHECK: }
+// CHECK: %[[HAS_MORE:.*]] = cmpi slt, %[[NEW_UB]], %[[UB]]
+// CHECK: scf.if %[[HAS_MORE]] {
+// CHECK: %[[REM:.*]] = affine.min #[[MAP2]]()[%[[UB]]]
+// CHECK: %[[LOAD2:.*]] = memref.load %[[MEMREF]][]
+// CHECK: %[[CAST2:.*]] = index_cast %[[REM]]
+// CHECK: %[[ADD2:.*]] = addi %[[LOAD2]], %[[CAST2]]
+// CHECK: memref.store %[[ADD2]], %[[MEMREF]]
+// CHECK: }
+// CHECK: return
+#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
+func @no_loop_results(%ub : index, %d : memref<i32>) {
+ %c0_i32 = constant 0 : i32
+ %lb = constant 0 : index
+ %step = constant 4 : index
+ scf.for %iv = %lb to %ub step %step {
+ %s = affine.min #map(%ub, %iv)[%step]
+ %r = memref.load %d[] : memref<i32>
+ %casted = index_cast %s : index to i32
+ %0 = addi %r, %casted : i32
+ memref.store %0, %d[] : memref<i32>
+ }
+ return
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index ecc341847a95e..6a684c06c918a 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1479,6 +1479,7 @@ cc_library(
includes = ["include"],
deps = [
":Affine",
+ ":DialectUtils",
":IR",
":MemRefDialect",
":Pass",
More information about the Mlir-commits
mailing list