[Mlir-commits] [mlir] f6f88e6 - [mlir] Add software pipelining transformation for scf.For op
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jul 19 13:44:18 PDT 2021
Author: thomasraoux
Date: 2021-07-19T13:43:26-07:00
New Revision: f6f88e66cedcebbabfaf7e68344d46835d9edc1d
URL: https://github.com/llvm/llvm-project/commit/f6f88e66cedcebbabfaf7e68344d46835d9edc1d
DIFF: https://github.com/llvm/llvm-project/commit/f6f88e66cedcebbabfaf7e68344d46835d9edc1d.diff
LOG: [mlir] Add software pipelining transformation for scf.For op
This is the first step to support software pipelining for scf.for loops.
This is only the transformation to create the pipelined kernel and
prologue/epilogue.
The schedule needs to be given by the user, as many different algorithms
and heuristics could be applied.
This currently doesn't handle loop arguments; support will be added in a
follow-up patch.
Differential Revision: https://reviews.llvm.org/D105868
Added:
mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
mlir/test/Dialect/SCF/loop-pipelining.mlir
Modified:
mlir/include/mlir/Dialect/SCF/Transforms.h
mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms.h
index 8ab13e42f477d..e1b0881d4af04 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms.h
@@ -23,10 +23,12 @@ class Region;
class TypeConverter;
class RewritePatternSet;
using OwningRewritePatternList = RewritePatternSet;
+class Operation;
namespace scf {
class ParallelOp;
+class ForOp;
/// Fuses all adjacent scf.parallel operations with identical bounds and step
/// into one scf.parallel operations. Uses a naive aliasing and dependency
@@ -64,6 +66,39 @@ void populateSCFStructuralTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target);
+/// Options to dictate how loops should be pipelined.
+struct PipeliningOption {
+  /// Callback filling the given vector with the operations of the forOp, each
+  /// paired with its pipeline stage, in the order picked for the pipelined
+  /// loop. Every operation of the loop except the scf.yield terminator must
+  /// appear; leaving the vector empty means the loop is not pipelined.
+  using GetScheduleFnType = std::function<void(
+      scf::ForOp, std::vector<std::pair<Operation *, unsigned>> &)>;
+  GetScheduleFnType getScheduleFn;
+  // TODO: add option to decide if the prologue/epilogue should be peeled.
+};
+
+/// Populate patterns for SCF software pipelining transformation.
+/// This transformation generates the pipelined loop and doesn't make any
+/// assumptions on the schedule dictated by the option structure.
+/// Software pipelining is usually done in two parts. The first part of
+/// pipelining is to schedule the loop and assign a stage and cycle to each
+/// operation. This is highly dependent on the target and is implemented as a
+/// heuristic based on operation latencies and other hardware characteristics.
+/// The second part is to take the schedule and generate the pipelined loop as
+/// well as the prologue and epilogue. It is independent of the target.
+/// This pattern only implements the second part.
+/// For example if we break a loop of N iterations (numbered 0 to N-1) into 3
+/// stages named S0, S1, S2, we would generate the following code, with the
+/// number in parentheses being the iteration index:
+///   S0(0)                        // Prologue
+///   S0(1) S1(0)                  // Prologue
+///   scf.for %I = %C0 to %N - 2 {
+///     S0(I+2) S1(I+1) S2(I)      // Pipelined kernel
+///   }
+///   S1(N-1) S2(N-2)              // Epilogue
+///   S2(N-1)                      // Epilogue
+void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns,
+                                       const PipeliningOption &options);
+
} // namespace scf
} // namespace mlir
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 689b53762a673..81d30581fa4e5 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRSCFTransforms
Bufferize.cpp
+ LoopPipelining.cpp
LoopRangeFolding.cpp
LoopSpecialization.cpp
ParallelLoopFusion.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
new file mode 100644
index 0000000000000..7cb36c958f4de
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -0,0 +1,385 @@
+//===- LoopPipelining.cpp - Code to perform loop software pipelining-------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements loop software pipelining
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/SCF/Utils.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/MathExtras.h"
+
+using namespace mlir;
+using namespace mlir::scf;
+
+namespace {
+
+/// Helper to keep internal information during pipelining transformation.
+struct LoopPipelinerInternal {
+  /// Coarse liverange information for values used across stages.
+  struct LiverangeInfo {
+    unsigned lastUseStage = 0;
+    unsigned defStage = 0;
+  };
+
+protected:
+  ForOp forOp;
+  unsigned maxStage = 0;
+  DenseMap<Operation *, unsigned> stages;
+  std::vector<Operation *> opOrder;
+  int64_t ub = 0;
+  int64_t lb = 0;
+  int64_t step = 0;
+
+  // When peeling the kernel we generate several versions of each value for
+  // different stages of the prologue. This map tracks the mapping between
+  // original Values in the loop and the different versions peeled from the
+  // loop.
+  DenseMap<Value, llvm::SmallVector<Value>> valueMapping;
+
+  /// Assign a value to `valueMapping`; this means `el` represents the version
+  /// `idx` of `key` (used for the prologue, the kernel results and the
+  /// epilogue).
+  void setValueMapping(Value key, Value el, int64_t idx);
+
+public:
+  /// Initialize the information for the given `op`; return true if it
+  /// satisfies the pre-conditions to apply pipelining.
+  bool initializeLoopInfo(ForOp op, const PipeliningOption &options);
+  /// Emits the prologue. This creates `maxStage` parts; part i (0-based)
+  /// contains the operations from stages [0; i].
+  void emitPrologue(PatternRewriter &rewriter);
+  /// Gather liverange information for Values that are used in a different
+  /// stage than their definition.
+  llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues();
+  /// Creates the kernel loop with a shortened upper bound. Each version of a
+  /// cross-stage value becomes a loop-carried argument; `loopArgMap` maps each
+  /// (value, stage distance) pair to the index of that argument.
+  scf::ForOp createKernelLoop(
+      const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
+      PatternRewriter &rewriter,
+      llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap);
+  /// Emits the pipelined kernel. This clones loop operations following user
+  /// order and remaps operands defined in a different stage than their use.
+  void createKernel(
+      scf::ForOp newForOp,
+      const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
+      const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
+      PatternRewriter &rewriter);
+  /// Emits the epilogue. This creates `maxStage` parts; part i (1-based)
+  /// contains the operations from stages [i; maxStage].
+  void emitEpilogue(PatternRewriter &rewriter);
+};
+
+/// Checks that `op` can be pipelined with the schedule produced by
+/// `options.getScheduleFn` and records the loop bounds, the stage of every
+/// scheduled operation and the user-picked operation order. Returns false
+/// (leaving the IR untouched) when any pre-condition is not met.
+bool LoopPipelinerInternal::initializeLoopInfo(
+    ForOp op, const PipeliningOption &options) {
+  forOp = op;
+  // TODO: Add support for loop with operands.
+  if (forOp.getNumIterOperands() > 0)
+    return false;
+  // Without a schedule callback there is nothing to pipeline; invoking an
+  // empty std::function would throw std::bad_function_call.
+  if (!options.getScheduleFn)
+    return false;
+
+  // The prologue and epilogue are fully peeled, so bounds and step must be
+  // compile-time constants.
+  auto upperBoundCst = forOp.upperBound().getDefiningOp<ConstantIndexOp>();
+  auto lowerBoundCst = forOp.lowerBound().getDefiningOp<ConstantIndexOp>();
+  auto stepCst = forOp.step().getDefiningOp<ConstantIndexOp>();
+  if (!upperBoundCst || !lowerBoundCst || !stepCst)
+    return false;
+  ub = upperBoundCst.getValue();
+  lb = lowerBoundCst.getValue();
+  step = stepCst.getValue();
+  // Guard the division in ceilDiv (and the iteration math used later).
+  if (step <= 0)
+    return false;
+  int64_t numIteration = ceilDiv(ub - lb, step);
+
+  std::vector<std::pair<Operation *, unsigned>> schedule;
+  options.getScheduleFn(forOp, schedule);
+  if (schedule.empty())
+    return false;
+
+  opOrder.reserve(schedule.size());
+  for (auto &opSchedule : schedule) {
+    Operation *scheduledOp = opSchedule.first;
+    // Each scheduled op is cloned as a whole (regions included) into the
+    // prologue, kernel and epilogue, and the kernel terminator is generated
+    // separately. Only accept non-terminator ops living directly in the loop
+    // body, each scheduled once. Together with the walk below this rejects
+    // bodies containing ops with nested operations: the kernel only remaps
+    // direct operands, not uses inside nested regions.
+    if (!scheduledOp || scheduledOp->getBlock() != forOp.getBody() ||
+        isa<scf::YieldOp>(scheduledOp))
+      return false;
+    if (!stages.insert(std::make_pair(scheduledOp, opSchedule.second)).second)
+      return false;
+    maxStage = std::max(maxStage, opSchedule.second);
+    opOrder.push_back(scheduledOp);
+  }
+  // Every stage must be exercised at least once by the peeled iterations.
+  if (numIteration <= maxStage)
+    return false;
+
+  // All operations need to have a stage.
+  if (forOp
+          .walk([this](Operation *op) {
+            if (op != forOp.getOperation() && !isa<scf::YieldOp>(op) &&
+                stages.find(op) == stages.end())
+              return WalkResult::interrupt();
+            return WalkResult::advance();
+          })
+          .wasInterrupted())
+    return false;
+
+  return true;
+}
+
+/// Emits the prologue at the current insertion point. Part i (0 <= i <
+/// maxStage) runs, for every op of stage s <= i, stage s of original
+/// iteration (i - s); the results are recorded as version (i - s) in
+/// `valueMapping`.
+void LoopPipelinerInternal::emitPrologue(PatternRewriter &rewriter) {
+  for (int64_t i = 0; i < maxStage; i++) {
+    // Special handling for the induction variable as the increment is
+    // implicit: original iteration i sees the value `lb + i * step`.
+    Value iv =
+        rewriter.create<ConstantIndexOp>(forOp.getLoc(), lb + i * step);
+    setValueMapping(forOp.getInductionVar(), iv, i);
+    for (Operation *op : opOrder) {
+      if (stages[op] > i)
+        continue;
+      Operation *newOp = rewriter.clone(*op);
+      // Rewire operands to the clone produced for the same original
+      // iteration (or the matching induction variable constant).
+      for (unsigned opIdx = 0; opIdx < op->getNumOperands(); opIdx++) {
+        auto it = valueMapping.find(op->getOperand(opIdx));
+        if (it != valueMapping.end())
+          newOp->setOperand(opIdx, it->second[i - stages[op]]);
+      }
+      for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
+        setValueMapping(op->getResult(destId), newOp->getResult(destId),
+                        i - stages[op]);
+      }
+    }
+  }
+}
+
+/// Collects every value that is consumed in a later stage than the one that
+/// defines it, together with its defining stage and its latest use stage.
+/// Operands without a defining op (e.g. the induction variable) or defined
+/// outside the scheduled ops are skipped. Entries keep first-seen order.
+llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
+LoopPipelinerInternal::analyzeCrossStageValues() {
+  llvm::MapVector<Value, LiverangeInfo> result;
+  for (Operation *user : opOrder) {
+    const unsigned useStage = stages[user];
+    for (Value operand : user->getOperands()) {
+      Operation *producer = operand.getDefiningOp();
+      if (producer == nullptr)
+        continue;
+      auto producerIt = stages.find(producer);
+      if (producerIt == stages.end())
+        continue;
+      const unsigned defStage = producerIt->second;
+      if (defStage == useStage)
+        continue;
+      assert(useStage > defStage);
+      LiverangeInfo &range = result[operand];
+      range.defStage = defStage;
+      range.lastUseStage = std::max(range.lastUseStage, useStage);
+    }
+  }
+  return result;
+}
+
+/// Creates the kernel scf.for loop. Its upper bound is reduced by `maxStage`
+/// iterations (those are peeled into the prologue/epilogue). Every version of
+/// a cross-stage value produced by the prologue becomes a loop-carried
+/// argument. `loopArgMap` receives, for each (value, stage distance) pair, the
+/// index of the corresponding loop-carried argument. The returned loop body is
+/// left without a terminator; createKernel emits it.
+scf::ForOp LoopPipelinerInternal::createKernelLoop(
+    const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
+        &crossStageValues,
+    PatternRewriter &rewriter,
+    llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap) {
+  // Creates the list of initial values associated to values used across
+  // stages. The initial values come from the prologue created above.
+  // Keep track of the kernel argument associated to each version of the
+  // values passed to the kernel.
+  auto newLoopArg = llvm::to_vector<8>(forOp.getIterOperands());
+  for (const auto &escape : crossStageValues) {
+    const LiverangeInfo &info = escape.second;
+    Value value = escape.first;
+    for (unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage;
+         stageIdx++) {
+      Value valueVersion =
+          valueMapping[value][maxStage - info.lastUseStage + stageIdx];
+      assert(valueVersion);
+      newLoopArg.push_back(valueVersion);
+      loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage -
+                                           stageIdx)] = newLoopArg.size() - 1;
+    }
+  }
+
+  // Create the new kernel loop. Since we need to peel `maxStage` iterations,
+  // we change the upper bound to remove those iterations.
+  Value newUb =
+      rewriter.create<ConstantIndexOp>(forOp.getLoc(), ub - maxStage * step);
+  auto newForOp = rewriter.create<scf::ForOp>(
+      forOp.getLoc(), forOp.lowerBound(), newUb, forOp.step(), newLoopArg);
+  // scf::ForOp's builder inserts a default scf.yield when no iter_args are
+  // given (i.e. when no value crosses stages). createKernel always emits its
+  // own terminator, so drop the implicit one to avoid two terminators.
+  if (!newForOp.getBody()->empty())
+    rewriter.eraseOp(newForOp.getBody()->getTerminator());
+  return newForOp;
+}
+
+/// Emits the body of the kernel loop `newForOp`. It clones the scheduled ops
+/// in user order and offsets induction-variable uses by
+/// `(maxStage - stage) * step`. Operands defined in an earlier stage are
+/// rewired to the loop-carried argument recorded in `loopArgMap`. Finally it
+/// emits the scf.yield that shifts the loop-carried versions, and records in
+/// `valueMapping` which loop result holds each version of a cross-stage value
+/// (read later by emitEpilogue).
+void LoopPipelinerInternal::createKernel(
+ scf::ForOp newForOp,
+ const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
+ &crossStageValues,
+ const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
+ PatternRewriter &rewriter) {
+ // Drop the prologue versions; the map is refilled below with the kernel
+ // loop results.
+ valueMapping.clear();
+
+ // Create the kernel: we clone instructions based on the order given by the
+ // user and remap operands coming from previous stages.
+ rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
+ BlockAndValueMapping mapping;
+ mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
+ for (Operation *op : opOrder) {
+ int64_t useStage = stages[op];
+ auto *newOp = rewriter.clone(*op, mapping);
+ for (OpOperand &operand : op->getOpOperands()) {
+ // Special case for the induction variable uses. We replace it with a
+ // version incremented based on the stage where it is used.
+ if (operand.get() == forOp.getInductionVar()) {
+ // Materialize the offset IV right before the cloned op, then restore
+ // the insertion point after it so the next clone keeps user order.
+ rewriter.setInsertionPoint(newOp);
+ Value offset = rewriter.create<ConstantIndexOp>(
+ forOp.getLoc(), (maxStage - stages[op]) * step);
+ Value iv = rewriter.create<AddIOp>(forOp.getLoc(),
+ newForOp.getInductionVar(), offset);
+ newOp->setOperand(operand.getOperandNumber(), iv);
+ rewriter.setInsertionPointAfter(newOp);
+ continue;
+ }
+ // For operands defined in a previous stage we need to remap them to use
+ // the correct region argument. We look for the right version of the
+ // Value based on the distance between its definition and use stages.
+ Operation *def = operand.get().getDefiningOp();
+ if (!def)
+ continue;
+ auto stageDef = stages.find(def);
+ if (stageDef == stages.end() || stageDef->second == useStage)
+ continue;
+ auto remap = loopArgMap.find(
+ std::make_pair(operand.get(), useStage - stageDef->second));
+ assert(remap != loopArgMap.end());
+ newOp->setOperand(operand.getOperandNumber(),
+ newForOp.getRegionIterArgs()[remap->second]);
+ }
+ }
+
+ // Collect the Values that need to be returned by the forOp. For each
+ // value we need to have `LastUseStage - DefStage` number of versions
+ // returned.
+ // We create a mapping between original values and the associated loop
+ // returned values that will be needed by the epilogue.
+ llvm::SmallVector<Value> yieldOperands;
+ for (auto &it : crossStageValues) {
+ int64_t version = maxStage - it.second.lastUseStage + 1;
+ unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
+ // If a liverange spans more than 2 stages, the older versions are shifted:
+ // each loop-carried argument yields the next (newer) argument of the same
+ // value. `getArguments()` starts with the induction variable(s), so the
+ // index below selects region iter arg `yieldOperands.size() + 1`.
+ for (unsigned i = 1; i < numVersionReturned; i++) {
+ setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
+ version++);
+ yieldOperands.push_back(
+ newForOp.getBody()->getArguments()[yieldOperands.size() + 1 +
+ newForOp.getNumInductionVars()]);
+ }
+ // The newest version is the value computed by this kernel iteration.
+ setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
+ version++);
+ yieldOperands.push_back(mapping.lookupOrDefault(it.first));
+ }
+ rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
+}
+
+/// Emits the epilogue at the current insertion point (after the kernel loop).
+/// Part i (1 <= i <= maxStage) runs every op of stage s >= i for original
+/// iteration `last - (s - i)`, reading operands from the kernel results or
+/// from earlier epilogue parts through `valueMapping`.
+void LoopPipelinerInternal::emitEpilogue(PatternRewriter &rewriter) {
+  // Emit different versions of the induction variable. They are expected to
+  // be removed as dead code if not used. Version `maxStage - i` holds the
+  // induction variable value `i` iterations before the last one.
+  for (int64_t i = 0; i < maxStage; i++) {
+    Value newlastIter = rewriter.create<ConstantIndexOp>(
+        forOp.getLoc(), lb + step * ((((ub - 1) - lb) / step) - i));
+    setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
+  }
+  // Emit `maxStage` epilogue parts; part i includes the operations from
+  // stages [i; maxStage].
+  for (int64_t i = 1; i <= maxStage; i++) {
+    for (Operation *op : opOrder) {
+      if (stages[op] < i)
+        continue;
+      Operation *newOp = rewriter.clone(*op);
+      for (unsigned opIdx = 0; opIdx < op->getNumOperands(); opIdx++) {
+        auto it = valueMapping.find(op->getOperand(opIdx));
+        if (it != valueMapping.end()) {
+          Value v = it->second[maxStage - stages[op] + i];
+          assert(v);
+          newOp->setOperand(opIdx, v);
+        }
+      }
+      for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
+        setValueMapping(op->getResult(destId), newOp->getResult(destId),
+                        maxStage - stages[op] + i);
+      }
+    }
+  }
+}
+
+/// Records `el` as version `idx` of `key`. The first time `key` is seen, a
+/// vector with one (initially null) slot per version 0..maxStage is created.
+void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
+  auto entry = valueMapping.find(key);
+  if (entry == valueMapping.end()) {
+    llvm::SmallVector<Value> versions(maxStage + 1);
+    entry = valueMapping.insert(std::make_pair(key, std::move(versions))).first;
+  }
+  entry->second[idx] = el;
+}
+
+/// Generate a pipelined version of the scf.for loop based on the schedule given
+/// as option. This applies the mechanical transformation of changing the loop
+/// and generating the prologue/epilogue for the pipelining and doesn't make any
+/// decision regarding the schedule.
+/// Based on the option the loop is split into several stages.
+/// The transformation assumes that the schedule given by the user is valid.
+/// For example if we break a loop of N iterations (numbered 0 to N-1) into 3
+/// stages named S0, S1, S2, we would generate the following code, with the
+/// number in parentheses being the iteration index:
+///   S0(0)                        // Prologue
+///   S0(1) S1(0)                  // Prologue
+///   scf.for %I = %C0 to %N - 2 {
+///     S0(I+2) S1(I+1) S2(I)      // Pipelined kernel
+///   }
+///   S1(N-1) S2(N-2)              // Epilogue
+///   S2(N-1)                      // Epilogue
+struct ForLoopPipelining : public OpRewritePattern<ForOp> {
+  ForLoopPipelining(const PipeliningOption &options, MLIRContext *context)
+      : OpRewritePattern<ForOp>(context), options(options) {}
+  LogicalResult matchAndRewrite(ForOp forOp,
+                                PatternRewriter &rewriter) const override {
+    LoopPipelinerInternal pipeliner;
+    if (!pipeliner.initializeLoopInfo(forOp, options))
+      return failure();
+
+    // 1. Emit prologue.
+    pipeliner.emitPrologue(rewriter);
+
+    // 2. Track values used across stages. When a value crosses stages it will
+    // need to be passed as loop iteration arguments.
+    // We first collect the values that are used in a different stage than
+    // where they are defined.
+    llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
+        crossStageValues = pipeliner.analyzeCrossStageValues();
+
+    // Mapping between original loop values used cross stage and the block
+    // arguments associated after pipelining. A Value may map to several
+    // arguments if its liverange spans across more than 2 stages.
+    llvm::DenseMap<std::pair<Value, unsigned>, unsigned> loopArgMap;
+    // 3. Create the new kernel loop and return the block arguments mapping.
+    ForOp newForOp =
+        pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap);
+    // Create the kernel block, order ops based on user choice and remap
+    // operands.
+    pipeliner.createKernel(newForOp, crossStageValues, loopArgMap, rewriter);
+
+    // 4. Emit the epilogue after the new forOp.
+    rewriter.setInsertionPointAfter(newForOp);
+    pipeliner.emitEpilogue(rewriter);
+
+    // 5. Erase the original loop and replace the uses with the epilogue output.
+    if (forOp->getNumResults() > 0)
+      rewriter.replaceOp(
+          forOp, newForOp.getResults().take_front(forOp->getNumResults()));
+    else
+      rewriter.eraseOp(forOp);
+
+    return success();
+  }
+
+protected:
+  PipeliningOption options;
+};
+
+} // namespace
+
+/// Adds the scf.for software-pipelining rewrite to `patterns`; the schedule
+/// used for each loop comes from `options.getScheduleFn`.
+void mlir::scf::populateSCFLoopPipeliningPatterns(
+    RewritePatternSet &patterns, const PipeliningOption &options) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<ForLoopPipelining>(options, ctx);
+}
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
new file mode 100644
index 0000000000000..fb3cce1ed7869
--- /dev/null
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -0,0 +1,173 @@
+// RUN: mlir-opt %s -test-scf-pipelining -split-input-file | FileCheck %s
+
+// CHECK-LABEL: simple_pipeline(
+// CHECK-SAME: %[[A:.*]]: memref<?xf32>, %[[R:.*]]: memref<?xf32>) {
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK-DAG: %[[C3:.*]] = constant 3 : index
+// Prologue:
+// CHECK: %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
+// Kernel:
+// CHECK-NEXT: %[[L1:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C3]]
+// CHECK-SAME: step %[[C1]] iter_args(%[[LARG:.*]] = %[[L0]]) -> (f32) {
+// CHECK-NEXT: %[[ADD0:.*]] = addf %[[LARG]], %{{.*}} : f32
+// CHECK-NEXT: memref.store %[[ADD0]], %[[R]][%[[IV]]] : memref<?xf32>
+// CHECK-NEXT: %[[IV1:.*]] = addi %[[IV]], %[[C1]] : index
+// CHECK-NEXT: %[[LR:.*]] = memref.load %[[A]][%[[IV1]]] : memref<?xf32>
+// CHECK-NEXT: scf.yield %[[LR]] : f32
+// CHECK-NEXT: }
+// Epilogue:
+// CHECK-NEXT: %[[ADD1:.*]] = addf %[[L1]], %{{.*}} : f32
+// CHECK-NEXT: memref.store %[[ADD1]], %[[R]][%[[C3]]] : memref<?xf32>
+// Two-stage schedule: the load is in stage 0 while the addf and the store are
+// in stage 1, so the loaded value is carried into the next kernel iteration.
+func @simple_pipeline(%A: memref<?xf32>, %result: memref<?xf32>) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c4 = constant 4 : index
+ %cf = constant 1.0 : f32
+ // op_order places the addf and store first, then the load for a later iteration.
+ scf.for %i0 = %c0 to %c4 step %c1 {
+ %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+ %A1_elem = addf %A_elem, %cf { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
+ memref.store %A1_elem, %result[%i0] { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : memref<?xf32>
+ } { __test_pipelining_loop__ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: three_stage(
+// CHECK-SAME: %[[A:.*]]: memref<?xf32>, %[[R:.*]]: memref<?xf32>) {
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = constant 2 : index
+// CHECK-DAG: %[[C3:.*]] = constant 3 : index
+// Prologue:
+// CHECK: %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
+// CHECK-NEXT: %[[ADD0:.*]] = addf %[[L0]], %{{.*}} : f32
+// CHECK-NEXT: %[[L1:.*]] = memref.load %[[A]][%[[C1]]] : memref<?xf32>
+// Kernel:
+// CHECK-NEXT: %[[LR:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C2]]
+// CHECK-SAME: step %[[C1]] iter_args(%[[ADDARG:.*]] = %[[ADD0]],
+// CHECK-SAME: %[[LARG:.*]] = %[[L1]]) -> (f32, f32) {
+// CHECK-NEXT: memref.store %[[ADDARG]], %[[R]][%[[IV]]] : memref<?xf32>
+// CHECK-NEXT: %[[ADD1:.*]] = addf %[[LARG]], %{{.*}} : f32
+// CHECK-NEXT: %[[IV2:.*]] = addi %[[IV]], %[[C2]] : index
+// CHECK-NEXT: %[[L3:.*]] = memref.load %[[A]][%[[IV2]]] : memref<?xf32>
+// CHECK-NEXT: scf.yield %[[ADD1]], %[[L3]] : f32, f32
+// CHECK-NEXT: }
+// Epilogue:
+// CHECK-NEXT: memref.store %[[LR]]#0, %[[R]][%[[C2]]] : memref<?xf32>
+// CHECK-NEXT: %[[ADD2:.*]] = addf %[[LR]]#1, %{{.*}} : f32
+// CHECK-NEXT: memref.store %[[ADD2]], %[[R]][%[[C3]]] : memref<?xf32>
+// Three-stage schedule (load: 0, addf: 1, store: 2): each stage boundary
+// crossed by a value adds one loop-carried value to the kernel.
+func @three_stage(%A: memref<?xf32>, %result: memref<?xf32>) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c4 = constant 4 : index
+ %cf = constant 1.0 : f32
+ scf.for %i0 = %c0 to %c4 step %c1 {
+ %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+ %A1_elem = addf %A_elem, %cf { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
+ memref.store %A1_elem, %result[%i0] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 0 } : memref<?xf32>
+ } { __test_pipelining_loop__ }
+ return
+}
+
+// -----
+// CHECK-LABEL: long_liverange(
+// CHECK-SAME: %[[A:.*]]: memref<?xf32>, %[[R:.*]]: memref<?xf32>) {
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = constant 2 : index
+// CHECK-DAG: %[[C3:.*]] = constant 3 : index
+// CHECK-DAG: %[[C4:.*]] = constant 4 : index
+// CHECK-DAG: %[[C6:.*]] = constant 6 : index
+// CHECK-DAG: %[[C7:.*]] = constant 7 : index
+// CHECK-DAG: %[[C8:.*]] = constant 8 : index
+// CHECK-DAG: %[[C9:.*]] = constant 9 : index
+// Prologue:
+// CHECK: %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
+// CHECK-NEXT: %[[L1:.*]] = memref.load %[[A]][%[[C1]]] : memref<?xf32>
+// CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[C2]]] : memref<?xf32>
+// CHECK-NEXT: %[[L3:.*]] = memref.load %[[A]][%[[C3]]] : memref<?xf32>
+// Kernel:
+// CHECK-NEXT: %[[LR:.*]]:4 = scf.for %[[IV:.*]] = %[[C0]] to %[[C6]]
+// CHECK-SAME: step %[[C1]] iter_args(%[[LA0:.*]] = %[[L0]],
+// CHECK-SAME: %[[LA1:.*]] = %[[L1]], %[[LA2:.*]] = %[[L2]],
+// CHECK-SAME: %[[LA3:.*]] = %[[L3]]) -> (f32, f32, f32, f32) {
+// CHECK-NEXT: %[[ADD0:.*]] = addf %[[LA0]], %{{.*}} : f32
+// CHECK-NEXT: memref.store %[[ADD0]], %[[R]][%[[IV]]] : memref<?xf32>
+// CHECK-NEXT: %[[IV4:.*]] = addi %[[IV]], %[[C4]] : index
+// CHECK-NEXT: %[[L4:.*]] = memref.load %[[A]][%[[IV4]]] : memref<?xf32>
+// CHECK-NEXT: scf.yield %[[LA1]], %[[LA2]], %[[LA3]], %[[L4]] : f32, f32, f32, f32
+// CHECK-NEXT: }
+// Epilogue:
+// CHECK-NEXT: %[[ADD1:.*]] = addf %[[LR]]#0, %{{.*}} : f32
+// CHECK-NEXT: memref.store %[[ADD1]], %[[R]][%[[C6]]] : memref<?xf32>
+// CHECK-NEXT: %[[ADD2:.*]] = addf %[[LR]]#1, %{{.*}} : f32
+// CHECK-NEXT: memref.store %[[ADD2]], %[[R]][%[[C7]]] : memref<?xf32>
+// CHECK-NEXT: %[[ADD3:.*]] = addf %[[LR]]#2, %{{.*}} : f32
+// CHECK-NEXT: memref.store %[[ADD3]], %[[R]][%[[C8]]] : memref<?xf32>
+// CHECK-NEXT: %[[ADD4:.*]] = addf %[[LR]]#3, %{{.*}} : f32
+// CHECK-NEXT: memref.store %[[ADD4]], %[[R]][%[[C9]]] : memref<?xf32>
+// The load (stage 0) is consumed in stage 4, so its value lives across four
+// stage boundaries and the kernel carries four versions of it.
+func @long_liverange(%A: memref<?xf32>, %result: memref<?xf32>) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c10 = constant 10 : index
+ %cf = constant 1.0 : f32
+ scf.for %i0 = %c0 to %c10 step %c1 {
+ %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+ %A1_elem = addf %A_elem, %cf { __test_pipelining_stage__ = 4, __test_pipelining_op_order__ = 0 } : f32
+ memref.store %A1_elem, %result[%i0] { __test_pipelining_stage__ = 4, __test_pipelining_op_order__ = 1 } : memref<?xf32>
+ } { __test_pipelining_loop__ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: multiple_uses(
+// CHECK-SAME: %[[A:.*]]: memref<?xf32>, %[[R:.*]]: memref<?xf32>) {
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = constant 2 : index
+// CHECK-DAG: %[[C3:.*]] = constant 3 : index
+// CHECK-DAG: %[[C7:.*]] = constant 7 : index
+// CHECK-DAG: %[[C8:.*]] = constant 8 : index
+// CHECK-DAG: %[[C9:.*]] = constant 9 : index
+// Prologue:
+// CHECK: %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
+// CHECK-NEXT: %[[ADD0:.*]] = addf %[[L0]], %{{.*}} : f32
+// CHECK-NEXT: %[[L1:.*]] = memref.load %[[A]][%[[C1]]] : memref<?xf32>
+// CHECK-NEXT: %[[ADD1:.*]] = addf %[[L1]], %{{.*}} : f32
+// CHECK-NEXT: %[[MUL0:.*]] = mulf %[[ADD0]], %[[L0]] : f32
+// CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[C2]]] : memref<?xf32>
+// Kernel:
+// CHECK-NEXT: %[[LR:.*]]:4 = scf.for %[[IV:.*]] = %[[C0]] to %[[C7]]
+// CHECK-SAME: step %[[C1]] iter_args(%[[LA1:.*]] = %[[L1]],
+// CHECK-SAME: %[[LA2:.*]] = %[[L2]], %[[ADDARG1:.*]] = %[[ADD1]],
+// CHECK-SAME: %[[MULARG0:.*]] = %[[MUL0]]) -> (f32, f32, f32, f32) {
+// CHECK-NEXT: %[[ADD2:.*]] = addf %[[LA2]], %{{.*}} : f32
+// CHECK-NEXT: %[[MUL1:.*]] = mulf %[[ADDARG1]], %[[LA1]] : f32
+// CHECK-NEXT: memref.store %[[MULARG0]], %[[R]][%[[IV]]] : memref<?xf32>
+// CHECK-NEXT: %[[IV3:.*]] = addi %[[IV]], %[[C3]] : index
+// CHECK-NEXT: %[[L3:.*]] = memref.load %[[A]][%[[IV3]]] : memref<?xf32>
+// CHECK-NEXT: scf.yield %[[LA2]], %[[L3]], %[[ADD2]], %[[MUL1]] : f32, f32, f32, f32
+// CHECK-NEXT: }
+// Epilogue:
+// CHECK-NEXT: %[[ADD3:.*]] = addf %[[LR]]#1, %{{.*}} : f32
+// CHECK-NEXT: %[[MUL2:.*]] = mulf %[[LR]]#2, %[[LR]]#0 : f32
+// CHECK-NEXT: memref.store %[[LR]]#3, %[[R]][%[[C7]]] : memref<?xf32>
+// CHECK-NEXT: %[[MUL3:.*]] = mulf %[[ADD3]], %[[LR]]#1 : f32
+// CHECK-NEXT: memref.store %[[MUL2]], %[[R]][%[[C8]]] : memref<?xf32>
+// CHECK-NEXT: memref.store %[[MUL3]], %[[R]][%[[C9]]] : memref<?xf32>
+// %A_elem is used both by the addf (stage 1) and by the mulf (stage 2), so the
+// number of carried versions of it is driven by its last use.
+func @multiple_uses(%A: memref<?xf32>, %result: memref<?xf32>) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c10 = constant 10 : index
+ %cf = constant 1.0 : f32
+ scf.for %i0 = %c0 to %c10 step %c1 {
+ %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 3 } : memref<?xf32>
+ %A1_elem = addf %A_elem, %cf { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
+ %A2_elem = mulf %A1_elem, %A_elem { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 1 } : f32
+ memref.store %A2_elem, %result[%i0] { __test_pipelining_stage__ = 3, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+ } { __test_pipelining_loop__ }
+ return
+}
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index 45fd2a1805621..2254c54b26a11 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -11,9 +11,12 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/SCF/Utils.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SetVector.h"
@@ -72,6 +75,54 @@ class TestSCFIfUtilsPass
});
}
};
+
+/// Unit attribute selecting the scf.for loops the test pass pipelines.
+static constexpr StringLiteral kTestPipeliningLoopMarker =
+    "__test_pipelining_loop__";
+/// Integer attribute carrying the pipeline stage of an operation.
+static constexpr StringLiteral kTestPipeliningStageMarker =
+    "__test_pipelining_stage__";
+/// Integer attribute giving the position of an operation in the pipelined
+/// loop body, i.e. the order in which operations should appear after
+/// pipelining.
+static constexpr StringLiteral kTestPipeliningOpOrderMarker =
+    "__test_pipelining_op_order__";
+
+/// Test pass pipelining every scf.for tagged with kTestPipeliningLoopMarker,
+/// using a schedule read from per-operation test attributes.
+class TestSCFPipeliningPass
+    : public PassWrapper<TestSCFPipeliningPass, FunctionPass> {
+public:
+  StringRef getArgument() const final { return "test-scf-pipelining"; }
+  StringRef getDescription() const final { return "test scf.forOp pipelining"; }
+  explicit TestSCFPipeliningPass() = default;
+
+  /// Builds the schedule of `forOp` from the test attributes: every op that
+  /// carries both a stage and an op-order attribute is placed at slot
+  /// `op_order` with its stage. Unmarked loops, or loops whose attributes do
+  /// not form a complete, duplicate-free, in-range ordering of the body, get
+  /// an empty schedule, which the pipeliner treats as "do not pipeline".
+  static void
+  getSchedule(scf::ForOp forOp,
+              std::vector<std::pair<Operation *, unsigned>> &schedule) {
+    if (!forOp->hasAttr(kTestPipeliningLoopMarker))
+      return;
+    // One slot per body operation, excluding the scf.yield terminator.
+    schedule.resize(forOp.getBody()->getOperations().size() - 1);
+    bool malformed = false;
+    forOp.walk([&](Operation *op) {
+      auto attrStage =
+          op->getAttrOfType<IntegerAttr>(kTestPipeliningStageMarker);
+      auto attrCycle =
+          op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker);
+      if (!attrCycle || !attrStage)
+        return WalkResult::advance();
+      int64_t order = attrCycle.getInt();
+      int64_t stage = attrStage.getInt();
+      // An out-of-range order would write past the end of `schedule`; two ops
+      // with the same order would silently drop one of them.
+      if (order < 0 || order >= static_cast<int64_t>(schedule.size()) ||
+          stage < 0 || schedule[order].first) {
+        malformed = true;
+        return WalkResult::interrupt();
+      }
+      schedule[order] = std::make_pair(op, unsigned(stage));
+      return WalkResult::advance();
+    });
+    // A slot left unfilled would hand a null operation to the pipeliner.
+    bool hasHole = llvm::any_of(
+        schedule, [](const std::pair<Operation *, unsigned> &entry) {
+          return entry.first == nullptr;
+        });
+    if (malformed || hasHole)
+      schedule.clear();
+  }
+
+  void runOnFunction() override {
+    RewritePatternSet patterns(&getContext());
+    mlir::scf::PipeliningOption options;
+    options.getScheduleFn = getSchedule;
+
+    scf::populateSCFLoopPipeliningPatterns(patterns, options);
+    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+    getFunction().walk([](Operation *op) {
+      // Clean up the markers.
+      op->removeAttr(kTestPipeliningStageMarker);
+      op->removeAttr(kTestPipeliningOpOrderMarker);
+    });
+  }
+};
} // namespace
namespace mlir {
@@ -79,6 +130,7 @@ namespace test {
+/// Registers TestSCFForUtilsPass, TestSCFIfUtilsPass and TestSCFPipeliningPass.
void registerTestSCFUtilsPass() {
PassRegistration<TestSCFForUtilsPass>();
PassRegistration<TestSCFIfUtilsPass>();
+ PassRegistration<TestSCFPipeliningPass>();
}
} // namespace test
} // namespace mlir
More information about the Mlir-commits
mailing list