[Mlir-commits] [mlir] [Affine] Create memory banks for memories used inside affine parallel loops (PR #115759)
Jiahan Xie
llvmlistbot at llvm.org
Mon Nov 11 11:27:18 PST 2024
https://github.com/jiahanxie353 created https://github.com/llvm/llvm-project/pull/115759
This patch partitions memories used inside affine parallel loops into evenly sized banks and replaces the old memrefs with the new ones throughout the program.
The motivation is that, although the loop is "parallel", each physical memory port can only handle one read/write per cycle, so we let users manually partition a logical memory into several physical memories based on their banking logic.
I chose to create a pass on `affine.parallel` instead of `scf.parallel` because it's useful to perform affine analysis after splitting the original memory into banks.
Currently, the pass only supports one-dimensional memories.
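As a rough illustration of the rewrite (a sketch distilled from the test case at the end of this PR; the SSA names %mem, %i, %bank0, %bank1, and %zero are placeholders, not the pass's actual output), a load from a logical memref<8xf32> with banking-factor=2 becomes a bank select plus an intra-bank access:

    // Before banking: one logical memory.
    %v = affine.load %mem[%i] : memref<8xf32>

    // After banking (sketch): the bank is selected by %i mod 2 and the
    // intra-bank offset is %i floordiv 2.
    %bank   = affine.apply affine_map<(d0) -> (d0 mod 2)>(%i)
    %offset = affine.apply affine_map<(d0) -> (d0 floordiv 2)>(%i)
    %v = scf.index_switch %bank -> f32
    case 0 {
      %v0 = affine.load %bank0[%offset] : memref<4xf32>
      scf.yield %v0 : f32
    }
    case 1 {
      %v1 = affine.load %bank1[%offset] : memref<4xf32>
      scf.yield %v1 : f32
    }
    default {
      scf.yield %zero : f32
    }

Stores are handled the same way, except that the switch produces no result and each case stores into its own bank.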
From 045d377b7aedc14e7d2cc272adec9a43f46c4661 Mon Sep 17 00:00:00 2001
From: Jiahan Xie <jx353 at cornell.edu>
Date: Sun, 10 Nov 2024 11:20:26 -0500
Subject: [PATCH 1/6] something partially working without rewriter, erase after
walk traversal
---
mlir/include/mlir/Dialect/Affine/Passes.h | 5 +
mlir/include/mlir/Dialect/Affine/Passes.td | 9 +
.../Dialect/Affine/Transforms/CMakeLists.txt | 2 +
.../Affine/Transforms/ParallelUnroll.cpp | 352 ++++++++++++++++++
4 files changed, 368 insertions(+)
create mode 100644 mlir/lib/Dialect/Affine/Transforms/ParallelUnroll.cpp
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h
index 61f24255f305f7..53a941acff1d9e 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.h
+++ b/mlir/include/mlir/Dialect/Affine/Passes.h
@@ -25,6 +25,7 @@ class FuncOp;
namespace affine {
class AffineForOp;
+class AffineParallelOp;
/// Fusion mode to attempt. The default mode `Greedy` does both
/// producer-consumer and sibling fusion.
@@ -108,6 +109,10 @@ std::unique_ptr<OperationPass<func::FuncOp>> createLoopUnrollPass(
std::unique_ptr<OperationPass<func::FuncOp>>
createLoopUnrollAndJamPass(int unrollJamFactor = -1);
+std::unique_ptr<OperationPass<func::FuncOp>> createParallelUnrollPass(
+ int unrollFactor = -1,
+ const std::function<unsigned(AffineParallelOp)> &getUnrollFactor = nullptr);
+
/// Creates a pass to pipeline explicit movement of data across levels of the
/// memory hierarchy.
std::unique_ptr<OperationPass<func::FuncOp>> createPipelineDataTransferPass();
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
index b08e803345f76e..55a3583b90a55b 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -381,6 +381,15 @@ def AffineParallelize : Pass<"affine-parallelize", "func::FuncOp"> {
];
}
+def AffineParallelUnroll : Pass<"affine-parallel-unroll", "func::FuncOp"> {
+ let summary = "Unroll affine parallel loops";
+ let constructor = "mlir::affine::createParallelUnrollPass()";
+ let options = [
+ Option<"unrollFactor", "unroll-factor", "unsigned", /*default=*/"1",
+ "Use this unroll factor for all loops being unrolled">
+ ];
+}
+
def AffineLoopNormalize : Pass<"affine-loop-normalize", "func::FuncOp"> {
let summary = "Apply normalization transformations to affine loop-like ops";
let constructor = "mlir::affine::createAffineLoopNormalizePass()";
diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
index 772f15335d907f..c10d9e6a04f746 100644
--- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
LoopTiling.cpp
LoopUnroll.cpp
LoopUnrollAndJam.cpp
+ ParallelUnroll.cpp
PipelineDataTransfer.cpp
ReifyValueBounds.cpp
SuperVectorize.cpp
@@ -39,5 +40,6 @@ add_mlir_dialect_library(MLIRAffineTransforms
MLIRValueBoundsOpInterface
MLIRVectorDialect
MLIRVectorUtils
+ MLIRSCFDialect
)
diff --git a/mlir/lib/Dialect/Affine/Transforms/ParallelUnroll.cpp b/mlir/lib/Dialect/Affine/Transforms/ParallelUnroll.cpp
new file mode 100644
index 00000000000000..62afbb0dbb1fad
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Transforms/ParallelUnroll.cpp
@@ -0,0 +1,352 @@
+//===- ParallelUnroll.cpp - Code to perform parallel loop unrolling
+//--------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements parallel loop unrolling.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/Affine/Passes.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cassert>
+
+namespace mlir {
+namespace affine {
+#define GEN_PASS_DEF_AFFINEPARALLELUNROLL
+#include "mlir/Dialect/Affine/Passes.h.inc"
+} // namespace affine
+} // namespace mlir
+
+#define DEBUG_TYPE "affine-parallel-unroll"
+
+using namespace mlir;
+using namespace mlir::affine;
+
+namespace {
+
+/// Unroll an `affine.parallel` operation by the `unrollFactor` specified in the
+/// attribute. Evenly splitting `memref`s that are present in the `parallel`
+/// region into smaller banks.
+struct ParallelUnroll
+ : public affine::impl::AffineParallelUnrollBase<ParallelUnroll> {
+ const std::function<unsigned(AffineParallelOp)> getUnrollFactor;
+ ParallelUnroll() : getUnrollFactor(nullptr) {}
+ ParallelUnroll(const ParallelUnroll &other) = default;
+ explicit ParallelUnroll(std::optional<unsigned> unrollFactor = std::nullopt,
+ const std::function<unsigned(AffineParallelOp)>
+ &getUnrollFactor = nullptr)
+ : getUnrollFactor(getUnrollFactor) {
+ if (unrollFactor)
+ this->unrollFactor = *unrollFactor;
+ }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<mlir::scf::SCFDialect>();
+ }
+
+ void runOnOperation() override;
+ LogicalResult parallelUnrollByFactor(AffineParallelOp parOp,
+ uint64_t unrollFactor);
+
+private:
+ // map from original memory definition to newly allocated banks
+ DenseMap<Value, SmallVector<Value>> memoryToBanks;
+ SmallVector<Operation *, 8> opsToErase;
+};
+} // namespace
+
+// Collect all memrefs used in the `parOp`'s region.
+DenseSet<Value> collectMemRefs(AffineParallelOp parOp) {
+ DenseSet<Value> memrefVals;
+ parOp.walk([&](Operation *op) {
+ for (auto operand : op->getOperands()) {
+ if (isa<MemRefType>(operand.getType()))
+ memrefVals.insert(operand);
+ }
+ return WalkResult::advance();
+ });
+ return memrefVals;
+}
+
+MemRefType computeBankedMemRefType(MemRefType originalType,
+ uint64_t bankingFactor) {
+ ArrayRef<int64_t> originalShape = originalType.getShape();
+ assert(!originalShape.empty() && "memref shape should not be empty");
+ assert(originalType.getRank() == 1 &&
+ "currently only support one dimension memories");
+ SmallVector<int64_t, 4> newShape(originalShape.begin(), originalShape.end());
+ assert(newShape.front() % bankingFactor == 0 &&
+ "memref shape must be divided by the banking factor");
+ // Now assuming banking the last dimension
+ newShape.front() /= bankingFactor;
+ MemRefType newMemRefType =
+ MemRefType::get(newShape, originalType.getElementType(),
+ originalType.getLayout(), originalType.getMemorySpace());
+
+ return newMemRefType;
+}
+
+SmallVector<Value> createBanks(Value originalMem, uint64_t unrollFactor) {
+ MemRefType originalMemRefType = cast<MemRefType>(originalMem.getType());
+ MemRefType newMemRefType =
+ computeBankedMemRefType(originalMemRefType, unrollFactor);
+ SmallVector<Value, 4> banks;
+ if (auto blockArgMem = dyn_cast<BlockArgument>(originalMem)) {
+ Block *block = blockArgMem.getOwner();
+ unsigned blockArgNum = blockArgMem.getArgNumber();
+
+ SmallVector<Type> banksType;
+ for (unsigned i = 0; i < unrollFactor; ++i) {
+ block->insertArgument(blockArgNum + 1 + i, newMemRefType,
+ blockArgMem.getLoc());
+ }
+
+ auto blockArgs = block->getArguments().slice(blockArgNum + 1, unrollFactor);
+ banks.append(blockArgs.begin(), blockArgs.end());
+ } else {
+ Operation *originalDef = originalMem.getDefiningOp();
+ Location loc = originalDef->getLoc();
+ OpBuilder builder(originalDef);
+ builder.setInsertionPointAfter(originalDef);
+ TypeSwitch<Operation *>(originalDef)
+ .Case<memref::AllocOp>([&](memref::AllocOp allocOp) {
+ for (uint bankCnt = 0; bankCnt < unrollFactor; bankCnt++) {
+ auto bankAllocOp =
+ builder.create<memref::AllocOp>(loc, newMemRefType);
+ banks.push_back(bankAllocOp);
+ }
+ })
+ .Case<memref::AllocaOp>([&](memref::AllocaOp allocaOp) {
+ for (uint bankCnt = 0; bankCnt < unrollFactor; bankCnt++) {
+ auto bankAllocaOp =
+ builder.create<memref::AllocaOp>(loc, newMemRefType);
+ banks.push_back(bankAllocaOp);
+ }
+ })
+ .Default([](Operation *) {
+ llvm_unreachable("Unhandled memory operation type");
+ });
+ }
+ return banks;
+}
+
+Value computeIntraBankingOffset(OpBuilder &builder, Location loc, Value address,
+ uint availableBanks) {
+ Value availBanksVal =
+ builder
+ .create<arith::ConstantOp>(loc, builder.getIndexAttr(availableBanks))
+ .getResult();
+ Value offset =
+ builder.create<arith::DivUIOp>(loc, address, availBanksVal).getResult();
+ return offset;
+}
+
+/// Unrolls a 'affine.parallel' op. Returns success if the loop was unrolled,
+/// failure otherwise. The default unroll factor is 4.
+LogicalResult ParallelUnroll::parallelUnrollByFactor(AffineParallelOp parOp,
+ uint64_t unrollFactor) {
+ // 1. identify memrefs in the parallel region,
+ // 2. create memory banks for each of those memories
+ // 2.1 maybe result of alloc/getglobal, etc
+ // 2.2 maybe block arguments
+ //
+
+ DenseSet<Value> memrefsInPar = collectMemRefs(parOp);
+ Location loc = parOp.getLoc();
+ OpBuilder builder(parOp);
+
+ DenseSet<Block *> blocksToModify;
+ for (auto memrefVal : memrefsInPar) {
+ SmallVector<Value> banks = createBanks(memrefVal, unrollFactor);
+ memoryToBanks[memrefVal] = banks;
+
+ for (auto *user : memrefVal.getUsers()) {
+ // if user is within parallel region
+ TypeSwitch<Operation *>(user)
+ .Case<affine::AffineLoadOp>([&](affine::AffineLoadOp loadOp) {
+ Value loadIndex = loadOp.getIndices().front();
+ builder.setInsertionPointToStart(parOp.getBody());
+ Value bankingFactorValue =
+ builder.create<mlir::arith::ConstantIndexOp>(loc, unrollFactor);
+ Value bankIndex = builder.create<mlir::arith::RemUIOp>(
+ loc, loadIndex, bankingFactorValue);
+ Value offset = computeIntraBankingOffset(builder, loc, loadIndex,
+ unrollFactor);
+
+ SmallVector<Type> resultTypes = {loadOp.getResult().getType()};
+
+ SmallVector<int64_t, 4> caseValues;
+ for (unsigned i = 0; i < unrollFactor; ++i)
+ caseValues.push_back(i);
+
+ builder.setInsertionPoint(user);
+ scf::IndexSwitchOp switchOp = builder.create<scf::IndexSwitchOp>(
+ loc, resultTypes, bankIndex, caseValues,
+ /*numRegions=*/unrollFactor);
+
+ for (unsigned i = 0; i < unrollFactor; ++i) {
+ Region &caseRegion = switchOp.getCaseRegions()[i];
+ builder.setInsertionPointToStart(&caseRegion.emplaceBlock());
+ Value bankedLoad =
+ builder.create<AffineLoadOp>(loc, banks[i], offset);
+ builder.create<scf::YieldOp>(loc, bankedLoad);
+ }
+
+ Region &defaultRegion = switchOp.getDefaultRegion();
+ assert(defaultRegion.empty() && "Default region should be empty");
+ builder.setInsertionPointToStart(&defaultRegion.emplaceBlock());
+
+ TypedAttr zeroAttr =
+ cast<TypedAttr>(builder.getZeroAttr(loadOp.getType()));
+ auto defaultValue =
+ builder.create<arith::ConstantOp>(loc, zeroAttr);
+ builder.create<scf::YieldOp>(loc, defaultValue.getResult());
+
+ loadOp.getResult().replaceAllUsesWith(switchOp.getResult(0));
+
+ user->erase();
+ })
+ .Case<affine::AffineStoreOp>([&](affine::AffineStoreOp storeOp) {
+ Value loadIndex = storeOp.getIndices().front();
+ builder.setInsertionPointToStart(parOp.getBody());
+ Value bankingFactorValue =
+ builder.create<mlir::arith::ConstantIndexOp>(loc, unrollFactor);
+ Value bankIndex = builder.create<mlir::arith::RemUIOp>(
+ loc, loadIndex, bankingFactorValue);
+ Value offset = computeIntraBankingOffset(builder, loc, loadIndex,
+ unrollFactor);
+
+ SmallVector<Type> resultTypes = {};
+
+ SmallVector<int64_t, 4> caseValues;
+ for (unsigned i = 0; i < unrollFactor; ++i)
+ caseValues.push_back(i);
+
+ builder.setInsertionPoint(user);
+ scf::IndexSwitchOp switchOp = builder.create<scf::IndexSwitchOp>(
+ loc, resultTypes, bankIndex, caseValues,
+ /*numRegions=*/unrollFactor);
+
+ for (unsigned i = 0; i < unrollFactor; ++i) {
+ Region &caseRegion = switchOp.getCaseRegions()[i];
+ builder.setInsertionPointToStart(&caseRegion.emplaceBlock());
+ builder.create<AffineStoreOp>(loc, storeOp.getValueToStore(),
+ banks[i], offset);
+ builder.create<scf::YieldOp>(loc);
+ }
+
+ Region &defaultRegion = switchOp.getDefaultRegion();
+ assert(defaultRegion.empty() && "Default region should be empty");
+ builder.setInsertionPointToStart(&defaultRegion.emplaceBlock());
+
+ builder.create<scf::YieldOp>(loc);
+
+ user->erase();
+ })
+ .Default([](Operation *op) {
+ op->emitWarning("Unhandled operation type");
+ op->dump();
+ });
+ }
+
+ for (auto *user : memrefVal.getUsers()) {
+ if (auto returnOp = dyn_cast<func::ReturnOp>(user)) {
+ OpBuilder builder(returnOp);
+ func::FuncOp funcOp = returnOp.getParentOp();
+ builder.setInsertionPointToEnd(&funcOp.getBlocks().front());
+ auto newReturnOp =
+ builder.create<func::ReturnOp>(loc, ValueRange(banks));
+ TypeRange newReturnType = TypeRange(banks);
+ FunctionType newFuncType = FunctionType::get(
+ funcOp.getContext(), funcOp.getFunctionType().getInputs(),
+ newReturnType);
+ funcOp.setType(newFuncType);
+ returnOp->replaceAllUsesWith(newReturnOp);
+ opsToErase.push_back(returnOp);
+ }
+ }
+
+ // TODO: if use is empty, we should delete the original block args; and
+ // reset function type
+ if (memrefVal.use_empty()) {
+ if (auto blockArg = dyn_cast<BlockArgument>(memrefVal)) {
+ blockArg.getOwner()->eraseArgument(blockArg.getArgNumber());
+ blocksToModify.insert(blockArg.getOwner());
+ } else {
+ memrefVal.getDefiningOp()->erase();
+ }
+ }
+ }
+
+ for (auto *block : blocksToModify) {
+ if (!isa<func::FuncOp>(block->getParentOp()))
+ continue;
+ func::FuncOp funcOp = cast<func::FuncOp>(block->getParentOp());
+ SmallVector<Type, 4> newArgTypes;
+ for (BlockArgument arg : funcOp.getArguments()) {
+ newArgTypes.push_back(arg.getType());
+ }
+ FunctionType newFuncType =
+ FunctionType::get(funcOp.getContext(), newArgTypes,
+ funcOp.getFunctionType().getResults());
+ funcOp.setType(newFuncType);
+ }
+
+ /// - `isDefinedOutsideRegion` returns true if the given value is invariant
+ /// with
+ /// respect to the given region. A common implementation might be:
+ /// `value.getParentRegion()->isProperAncestor(region)`.
+
+ if (unrollFactor == 1) {
+ // TODO: how to address "expected pattern to replace the root operation" if
+ // just simply return success
+ return success();
+ }
+
+ return success();
+}
+
+void ParallelUnroll::runOnOperation() {
+ if (getOperation().isExternal()) {
+ return;
+ }
+
+ getOperation().walk([&](AffineParallelOp parOp) {
+ (void)parallelUnrollByFactor(parOp, unrollFactor);
+ return WalkResult::advance();
+ });
+ for (auto *op : opsToErase) {
+ op->erase();
+ }
+}
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::affine::createParallelUnrollPass(
+ int unrollFactor,
+ const std::function<unsigned(AffineParallelOp)> &getUnrollFactor) {
+ return std::make_unique<ParallelUnroll>(
+ unrollFactor == -1 ? std::nullopt : std::optional<unsigned>(unrollFactor),
+ getUnrollFactor);
+}
From 9bcc13ca65d6850348a9499742b86d06877a2a57 Mon Sep 17 00:00:00 2001
From: Jiahan Xie <jx353 at cornell.edu>
Date: Sun, 10 Nov 2024 13:39:01 -0500
Subject: [PATCH 2/6] use rewriter pattern partially working
---
.../Affine/Transforms/ParallelUnroll.cpp | 328 ++++++++++--------
1 file changed, 181 insertions(+), 147 deletions(-)
diff --git a/mlir/lib/Dialect/Affine/Transforms/ParallelUnroll.cpp b/mlir/lib/Dialect/Affine/Transforms/ParallelUnroll.cpp
index 62afbb0dbb1fad..6dd5e6e937e474 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ParallelUnroll.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ParallelUnroll.cpp
@@ -25,6 +25,8 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
@@ -73,7 +75,6 @@ struct ParallelUnroll
private:
// map from original memory definition to newly allocated banks
DenseMap<Value, SmallVector<Value>> memoryToBanks;
- SmallVector<Operation *, 8> opsToErase;
};
} // namespace
@@ -163,133 +164,191 @@ Value computeIntraBankingOffset(OpBuilder &builder, Location loc, Value address,
return offset;
}
-/// Unrolls a 'affine.parallel' op. Returns success if the loop was unrolled,
-/// failure otherwise. The default unroll factor is 4.
-LogicalResult ParallelUnroll::parallelUnrollByFactor(AffineParallelOp parOp,
- uint64_t unrollFactor) {
- // 1. identify memrefs in the parallel region,
- // 2. create memory banks for each of those memories
- // 2.1 maybe result of alloc/getglobal, etc
- // 2.2 maybe block arguments
- //
+struct BankAffineLoadPattern : public OpRewritePattern<AffineLoadOp> {
+ BankAffineLoadPattern(MLIRContext *context, uint64_t unrollFactor,
+ DenseMap<Value, SmallVector<Value>> &memoryToBanks)
+ : OpRewritePattern<AffineLoadOp>(context), unrollFactor(unrollFactor),
+ memoryToBanks(memoryToBanks) {}
+
+ LogicalResult matchAndRewrite(AffineLoadOp loadOp,
+ PatternRewriter &rewriter) const override {
+ llvm::errs() << "load pattern matchAndRewrite\n";
+ Location loc = loadOp.getLoc();
+ auto banks = memoryToBanks[loadOp.getMemref()];
+ Value loadIndex = loadOp.getIndices().front();
+ rewriter.setInsertionPointToStart(loadOp->getBlock());
+ Value bankingFactorValue =
+ rewriter.create<mlir::arith::ConstantIndexOp>(loc, unrollFactor);
+ Value bankIndex = rewriter.create<mlir::arith::RemUIOp>(loc, loadIndex,
+ bankingFactorValue);
+ Value offset =
+ computeIntraBankingOffset(rewriter, loc, loadIndex, unrollFactor);
+
+ SmallVector<Type> resultTypes = {loadOp.getResult().getType()};
+
+ SmallVector<int64_t, 4> caseValues;
+ for (unsigned i = 0; i < unrollFactor; ++i)
+ caseValues.push_back(i);
+
+ rewriter.setInsertionPoint(loadOp);
+ scf::IndexSwitchOp switchOp = rewriter.create<scf::IndexSwitchOp>(
+ loc, resultTypes, bankIndex, caseValues,
+ /*numRegions=*/unrollFactor);
- DenseSet<Value> memrefsInPar = collectMemRefs(parOp);
- Location loc = parOp.getLoc();
- OpBuilder builder(parOp);
+ for (unsigned i = 0; i < unrollFactor; ++i) {
+ Region &caseRegion = switchOp.getCaseRegions()[i];
+ rewriter.setInsertionPointToStart(&caseRegion.emplaceBlock());
+ Value bankedLoad = rewriter.create<AffineLoadOp>(loc, banks[i], offset);
+ rewriter.create<scf::YieldOp>(loc, bankedLoad);
+ }
- DenseSet<Block *> blocksToModify;
- for (auto memrefVal : memrefsInPar) {
- SmallVector<Value> banks = createBanks(memrefVal, unrollFactor);
- memoryToBanks[memrefVal] = banks;
-
- for (auto *user : memrefVal.getUsers()) {
- // if user is within parallel region
- TypeSwitch<Operation *>(user)
- .Case<affine::AffineLoadOp>([&](affine::AffineLoadOp loadOp) {
- Value loadIndex = loadOp.getIndices().front();
- builder.setInsertionPointToStart(parOp.getBody());
- Value bankingFactorValue =
- builder.create<mlir::arith::ConstantIndexOp>(loc, unrollFactor);
- Value bankIndex = builder.create<mlir::arith::RemUIOp>(
- loc, loadIndex, bankingFactorValue);
- Value offset = computeIntraBankingOffset(builder, loc, loadIndex,
- unrollFactor);
-
- SmallVector<Type> resultTypes = {loadOp.getResult().getType()};
-
- SmallVector<int64_t, 4> caseValues;
- for (unsigned i = 0; i < unrollFactor; ++i)
- caseValues.push_back(i);
-
- builder.setInsertionPoint(user);
- scf::IndexSwitchOp switchOp = builder.create<scf::IndexSwitchOp>(
- loc, resultTypes, bankIndex, caseValues,
- /*numRegions=*/unrollFactor);
-
- for (unsigned i = 0; i < unrollFactor; ++i) {
- Region &caseRegion = switchOp.getCaseRegions()[i];
- builder.setInsertionPointToStart(&caseRegion.emplaceBlock());
- Value bankedLoad =
- builder.create<AffineLoadOp>(loc, banks[i], offset);
- builder.create<scf::YieldOp>(loc, bankedLoad);
- }
-
- Region &defaultRegion = switchOp.getDefaultRegion();
- assert(defaultRegion.empty() && "Default region should be empty");
- builder.setInsertionPointToStart(&defaultRegion.emplaceBlock());
-
- TypedAttr zeroAttr =
- cast<TypedAttr>(builder.getZeroAttr(loadOp.getType()));
- auto defaultValue =
- builder.create<arith::ConstantOp>(loc, zeroAttr);
- builder.create<scf::YieldOp>(loc, defaultValue.getResult());
-
- loadOp.getResult().replaceAllUsesWith(switchOp.getResult(0));
-
- user->erase();
- })
- .Case<affine::AffineStoreOp>([&](affine::AffineStoreOp storeOp) {
- Value loadIndex = storeOp.getIndices().front();
- builder.setInsertionPointToStart(parOp.getBody());
- Value bankingFactorValue =
- builder.create<mlir::arith::ConstantIndexOp>(loc, unrollFactor);
- Value bankIndex = builder.create<mlir::arith::RemUIOp>(
- loc, loadIndex, bankingFactorValue);
- Value offset = computeIntraBankingOffset(builder, loc, loadIndex,
- unrollFactor);
-
- SmallVector<Type> resultTypes = {};
-
- SmallVector<int64_t, 4> caseValues;
- for (unsigned i = 0; i < unrollFactor; ++i)
- caseValues.push_back(i);
-
- builder.setInsertionPoint(user);
- scf::IndexSwitchOp switchOp = builder.create<scf::IndexSwitchOp>(
- loc, resultTypes, bankIndex, caseValues,
- /*numRegions=*/unrollFactor);
-
- for (unsigned i = 0; i < unrollFactor; ++i) {
- Region &caseRegion = switchOp.getCaseRegions()[i];
- builder.setInsertionPointToStart(&caseRegion.emplaceBlock());
- builder.create<AffineStoreOp>(loc, storeOp.getValueToStore(),
- banks[i], offset);
- builder.create<scf::YieldOp>(loc);
- }
-
- Region &defaultRegion = switchOp.getDefaultRegion();
- assert(defaultRegion.empty() && "Default region should be empty");
- builder.setInsertionPointToStart(&defaultRegion.emplaceBlock());
-
- builder.create<scf::YieldOp>(loc);
-
- user->erase();
- })
- .Default([](Operation *op) {
- op->emitWarning("Unhandled operation type");
- op->dump();
- });
+ Region &defaultRegion = switchOp.getDefaultRegion();
+ assert(defaultRegion.empty() && "Default region should be empty");
+ rewriter.setInsertionPointToStart(&defaultRegion.emplaceBlock());
+
+ TypedAttr zeroAttr =
+ cast<TypedAttr>(rewriter.getZeroAttr(loadOp.getType()));
+ auto defaultValue = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
+ rewriter.create<scf::YieldOp>(loc, defaultValue.getResult());
+
+ loadOp.getResult().replaceAllUsesWith(switchOp.getResult(0));
+
+ rewriter.eraseOp(loadOp);
+ return success();
+ }
+
+private:
+ uint64_t unrollFactor;
+ DenseMap<Value, SmallVector<Value>> &memoryToBanks;
+};
+
+struct BankAffineStorePattern : public OpRewritePattern<AffineStoreOp> {
+ BankAffineStorePattern(MLIRContext *context, uint64_t unrollFactor,
+ DenseMap<Value, SmallVector<Value>> &memoryToBanks)
+ : OpRewritePattern<AffineStoreOp>(context), unrollFactor(unrollFactor),
+ memoryToBanks(memoryToBanks) {}
+
+ LogicalResult matchAndRewrite(AffineStoreOp storeOp,
+ PatternRewriter &rewriter) const override {
+ llvm::errs() << "store pattern matchAndRewrite\n";
+ Location loc = storeOp.getLoc();
+ auto banks = memoryToBanks[storeOp.getMemref()];
+ Value loadIndex = storeOp.getIndices().front();
+ rewriter.setInsertionPointToStart(storeOp->getBlock());
+ Value bankingFactorValue =
+ rewriter.create<mlir::arith::ConstantIndexOp>(loc, unrollFactor);
+ Value bankIndex = rewriter.create<mlir::arith::RemUIOp>(loc, loadIndex,
+ bankingFactorValue);
+ Value offset =
+ computeIntraBankingOffset(rewriter, loc, loadIndex, unrollFactor);
+
+ SmallVector<Type> resultTypes = {};
+
+ SmallVector<int64_t, 4> caseValues;
+ for (unsigned i = 0; i < unrollFactor; ++i)
+ caseValues.push_back(i);
+
+ rewriter.setInsertionPoint(storeOp);
+ scf::IndexSwitchOp switchOp = rewriter.create<scf::IndexSwitchOp>(
+ loc, resultTypes, bankIndex, caseValues,
+ /*numRegions=*/unrollFactor);
+
+ for (unsigned i = 0; i < unrollFactor; ++i) {
+ Region &caseRegion = switchOp.getCaseRegions()[i];
+ rewriter.setInsertionPointToStart(&caseRegion.emplaceBlock());
+ rewriter.create<AffineStoreOp>(loc, storeOp.getValueToStore(), banks[i],
+ offset);
+ rewriter.create<scf::YieldOp>(loc);
}
- for (auto *user : memrefVal.getUsers()) {
- if (auto returnOp = dyn_cast<func::ReturnOp>(user)) {
- OpBuilder builder(returnOp);
- func::FuncOp funcOp = returnOp.getParentOp();
- builder.setInsertionPointToEnd(&funcOp.getBlocks().front());
- auto newReturnOp =
- builder.create<func::ReturnOp>(loc, ValueRange(banks));
- TypeRange newReturnType = TypeRange(banks);
- FunctionType newFuncType = FunctionType::get(
- funcOp.getContext(), funcOp.getFunctionType().getInputs(),
- newReturnType);
- funcOp.setType(newFuncType);
- returnOp->replaceAllUsesWith(newReturnOp);
- opsToErase.push_back(returnOp);
+ Region &defaultRegion = switchOp.getDefaultRegion();
+ assert(defaultRegion.empty() && "Default region should be empty");
+ rewriter.setInsertionPointToStart(&defaultRegion.emplaceBlock());
+
+ rewriter.create<scf::YieldOp>(loc);
+
+ rewriter.eraseOp(storeOp);
+ return success();
+ }
+
+private:
+ uint64_t unrollFactor;
+ DenseMap<Value, SmallVector<Value>> &memoryToBanks;
+};
+
+struct BankReturnPattern : public OpRewritePattern<func::ReturnOp> {
+ BankReturnPattern(MLIRContext *context,
+ DenseMap<Value, SmallVector<Value>> &memoryToBanks)
+ : OpRewritePattern<func::ReturnOp>(context),
+ memoryToBanks(memoryToBanks) {}
+
+ LogicalResult matchAndRewrite(func::ReturnOp returnOp,
+ PatternRewriter &rewriter) const override {
+ Location loc = returnOp.getLoc();
+ SmallVector<Value, 4> newReturnOperands;
+ bool allOrigMemsUsedByReturn = true;
+ for (auto operand : returnOp.getOperands()) {
+ if (!memoryToBanks.contains(operand)) {
+ newReturnOperands.push_back(operand);
+ continue;
}
+ if (operand.hasOneUse())
+ allOrigMemsUsedByReturn = false;
+ auto banks = memoryToBanks[operand];
+ newReturnOperands.append(banks.begin(), banks.end());
+ }
+ func::FuncOp funcOp = returnOp.getParentOp();
+ rewriter.setInsertionPointToEnd(&funcOp.getBlocks().front());
+ auto newReturnOp =
+ rewriter.create<func::ReturnOp>(loc, ValueRange(newReturnOperands));
+ TypeRange newReturnType = TypeRange(newReturnOperands);
+ FunctionType newFuncType =
+ FunctionType::get(funcOp.getContext(),
+ funcOp.getFunctionType().getInputs(), newReturnType);
+ funcOp.setType(newFuncType);
+
+ if (allOrigMemsUsedByReturn) {
+ rewriter.replaceOp(returnOp, newReturnOp);
}
+ return success();
+ }
- // TODO: if use is empty, we should delete the original block args; and
- // reset function type
+private:
+ DenseMap<Value, SmallVector<Value>> &memoryToBanks;
+};
+
+void ParallelUnroll::runOnOperation() {
+ if (getOperation().isExternal()) {
+ return;
+ }
+
+ getOperation().walk([&](AffineParallelOp parOp) {
+ DenseSet<Value> memrefsInPar = collectMemRefs(parOp);
+
+ for (auto memrefVal : memrefsInPar) {
+ SmallVector<Value> banks = createBanks(memrefVal, unrollFactor);
+ memoryToBanks[memrefVal] = banks;
+ }
+ });
+
+ auto *ctx = &getContext();
+
+ RewritePatternSet patterns(ctx);
+
+ patterns.add<BankAffineLoadPattern>(ctx, unrollFactor, memoryToBanks);
+ patterns.add<BankAffineStorePattern>(ctx, unrollFactor, memoryToBanks);
+ patterns.add<BankReturnPattern>(ctx, memoryToBanks);
+
+ GreedyRewriteConfig config;
+ config.strictMode = GreedyRewriteStrictness::ExistingOps;
+
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+ config))) {
+ signalPassFailure();
+ }
+
+ DenseSet<Block *> blocksToModify;
+ for (auto &[memrefVal, banks] : memoryToBanks) {
if (memrefVal.use_empty()) {
if (auto blockArg = dyn_cast<BlockArgument>(memrefVal)) {
blockArg.getOwner()->eraseArgument(blockArg.getArgNumber());
@@ -314,32 +373,7 @@ LogicalResult ParallelUnroll::parallelUnrollByFactor(AffineParallelOp parOp,
funcOp.setType(newFuncType);
}
- /// - `isDefinedOutsideRegion` returns true if the given value is invariant
- /// with
- /// respect to the given region. A common implementation might be:
- /// `value.getParentRegion()->isProperAncestor(region)`.
-
- if (unrollFactor == 1) {
- // TODO: how to address "expected pattern to replace the root operation" if
- // just simply return success
- return success();
- }
-
- return success();
-}
-
-void ParallelUnroll::runOnOperation() {
- if (getOperation().isExternal()) {
- return;
- }
-
- getOperation().walk([&](AffineParallelOp parOp) {
- (void)parallelUnrollByFactor(parOp, unrollFactor);
- return WalkResult::advance();
- });
- for (auto *op : opsToErase) {
- op->erase();
- }
+ getOperation().dump();
}
std::unique_ptr<OperationPass<func::FuncOp>>
From b65b2df300720ae7ece6d29b9d2e742d67b669a3 Mon Sep 17 00:00:00 2001
From: Jiahan Xie <jx353 at cornell.edu>
Date: Mon, 11 Nov 2024 12:44:55 -0500
Subject: [PATCH 3/6] use affine map to load
---
.../Affine/Transforms/ParallelUnroll.cpp | 33 ++++++++++---------
1 file changed, 18 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/Affine/Transforms/ParallelUnroll.cpp b/mlir/lib/Dialect/Affine/Transforms/ParallelUnroll.cpp
index 6dd5e6e937e474..a3072379d50058 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ParallelUnroll.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ParallelUnroll.cpp
@@ -176,13 +176,14 @@ struct BankAffineLoadPattern : public OpRewritePattern<AffineLoadOp> {
Location loc = loadOp.getLoc();
auto banks = memoryToBanks[loadOp.getMemref()];
Value loadIndex = loadOp.getIndices().front();
- rewriter.setInsertionPointToStart(loadOp->getBlock());
- Value bankingFactorValue =
- rewriter.create<mlir::arith::ConstantIndexOp>(loc, unrollFactor);
- Value bankIndex = rewriter.create<mlir::arith::RemUIOp>(loc, loadIndex,
- bankingFactorValue);
- Value offset =
- computeIntraBankingOffset(rewriter, loc, loadIndex, unrollFactor);
+ auto modMap =
+ AffineMap::get(1, 0, {rewriter.getAffineDimExpr(0) % unrollFactor});
+ auto divMap = AffineMap::get(
+ 1, 0, {rewriter.getAffineDimExpr(0).floorDiv(unrollFactor)});
+
+ Value bankIndex = rewriter.create<AffineApplyOp>(
+ loc, modMap, loadIndex); // assuming one-dim
+ Value offset = rewriter.create<AffineApplyOp>(loc, divMap, loadIndex);
SmallVector<Type> resultTypes = {loadOp.getResult().getType()};
@@ -233,14 +234,16 @@ struct BankAffineStorePattern : public OpRewritePattern<AffineStoreOp> {
llvm::errs() << "store pattern matchAndRewrite\n";
Location loc = storeOp.getLoc();
auto banks = memoryToBanks[storeOp.getMemref()];
- Value loadIndex = storeOp.getIndices().front();
- rewriter.setInsertionPointToStart(storeOp->getBlock());
- Value bankingFactorValue =
- rewriter.create<mlir::arith::ConstantIndexOp>(loc, unrollFactor);
- Value bankIndex = rewriter.create<mlir::arith::RemUIOp>(loc, loadIndex,
- bankingFactorValue);
- Value offset =
- computeIntraBankingOffset(rewriter, loc, loadIndex, unrollFactor);
+ Value storeIndex = storeOp.getIndices().front();
+
+ auto modMap =
+ AffineMap::get(1, 0, {rewriter.getAffineDimExpr(0) % unrollFactor});
+ auto divMap = AffineMap::get(
+ 1, 0, {rewriter.getAffineDimExpr(0).floorDiv(unrollFactor)});
+
+ Value bankIndex = rewriter.create<AffineApplyOp>(
+ loc, modMap, storeIndex); // assuming one-dim
+ Value offset = rewriter.create<AffineApplyOp>(loc, divMap, storeIndex);
SmallVector<Type> resultTypes = {};
>From 9976553d52631e32f13270200baeab9712e3d336 Mon Sep 17 00:00:00 2001
From: Jiahan Xie <jx353 at cornell.edu>
Date: Mon, 11 Nov 2024 13:25:10 -0500
Subject: [PATCH 4/6] clean up old memrefs
---
.../Affine/Transforms/ParallelUnroll.cpp | 79 +++++++++----------
1 file changed, 37 insertions(+), 42 deletions(-)
diff --git a/mlir/lib/Dialect/Affine/Transforms/ParallelUnroll.cpp b/mlir/lib/Dialect/Affine/Transforms/ParallelUnroll.cpp
index a3072379d50058..41cf46a1603020 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ParallelUnroll.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ParallelUnroll.cpp
@@ -100,7 +100,6 @@ MemRefType computeBankedMemRefType(MemRefType originalType,
SmallVector<int64_t, 4> newShape(originalShape.begin(), originalShape.end());
assert(newShape.front() % bankingFactor == 0 &&
"memref shape must be divided by the banking factor");
- // Now assuming banking the last dimension
newShape.front() /= bankingFactor;
MemRefType newMemRefType =
MemRefType::get(newShape, originalType.getElementType(),
@@ -153,17 +152,6 @@ SmallVector<Value> createBanks(Value originalMem, uint64_t unrollFactor) {
return banks;
}
-Value computeIntraBankingOffset(OpBuilder &builder, Location loc, Value address,
- uint availableBanks) {
- Value availBanksVal =
- builder
- .create<arith::ConstantOp>(loc, builder.getIndexAttr(availableBanks))
- .getResult();
- Value offset =
- builder.create<arith::DivUIOp>(loc, address, availBanksVal).getResult();
- return offset;
-}
-
struct BankAffineLoadPattern : public OpRewritePattern<AffineLoadOp> {
BankAffineLoadPattern(MLIRContext *context, uint64_t unrollFactor,
DenseMap<Value, SmallVector<Value>> &memoryToBanks)
@@ -172,7 +160,6 @@ struct BankAffineLoadPattern : public OpRewritePattern<AffineLoadOp> {
LogicalResult matchAndRewrite(AffineLoadOp loadOp,
PatternRewriter &rewriter) const override {
- llvm::errs() << "load pattern matchAndRewrite\n";
Location loc = loadOp.getLoc();
auto banks = memoryToBanks[loadOp.getMemref()];
Value loadIndex = loadOp.getIndices().front();
@@ -231,7 +218,6 @@ struct BankAffineStorePattern : public OpRewritePattern<AffineStoreOp> {
LogicalResult matchAndRewrite(AffineStoreOp storeOp,
PatternRewriter &rewriter) const override {
- llvm::errs() << "store pattern matchAndRewrite\n";
Location loc = storeOp.getLoc();
auto banks = memoryToBanks[storeOp.getMemref()];
Value storeIndex = storeOp.getIndices().front();
@@ -300,6 +286,7 @@ struct BankReturnPattern : public OpRewritePattern<func::ReturnOp> {
auto banks = memoryToBanks[operand];
newReturnOperands.append(banks.begin(), banks.end());
}
+
func::FuncOp funcOp = returnOp.getParentOp();
rewriter.setInsertionPointToEnd(&funcOp.getBlocks().front());
auto newReturnOp =
@@ -310,9 +297,9 @@ struct BankReturnPattern : public OpRewritePattern<func::ReturnOp> {
funcOp.getFunctionType().getInputs(), newReturnType);
funcOp.setType(newFuncType);
- if (allOrigMemsUsedByReturn) {
+ if (allOrigMemsUsedByReturn)
rewriter.replaceOp(returnOp, newReturnOp);
- }
+
return success();
}
@@ -320,6 +307,34 @@ struct BankReturnPattern : public OpRewritePattern<func::ReturnOp> {
DenseMap<Value, SmallVector<Value>> &memoryToBanks;
};
+LogicalResult cleanUpOldMemRefs(DenseSet<Value> &oldMemRefVals) {
+ DenseSet<func::FuncOp> funcsToModify;
+ for (auto &memrefVal : oldMemRefVals) {
+ if (!memrefVal.use_empty())
+ continue;
+ if (auto blockArg = dyn_cast<BlockArgument>(memrefVal)) {
+ Block *block = blockArg.getOwner();
+ block->eraseArgument(blockArg.getArgNumber());
+ if (auto funcOp = dyn_cast<func::FuncOp>(block->getParentOp()))
+ funcsToModify.insert(funcOp);
+ } else
+ memrefVal.getDefiningOp()->erase();
+ }
+
+ // Modify the function type accordingly
+ for (auto funcOp : funcsToModify) {
+ SmallVector<Type, 4> newArgTypes;
+ for (BlockArgument arg : funcOp.getArguments()) {
+ newArgTypes.push_back(arg.getType());
+ }
+ FunctionType newFuncType =
+ FunctionType::get(funcOp.getContext(), newArgTypes,
+ funcOp.getFunctionType().getResults());
+ funcOp.setType(newFuncType);
+ }
+ return success();
+}
+
void ParallelUnroll::runOnOperation() {
if (getOperation().isExternal()) {
return;
@@ -335,7 +350,6 @@ void ParallelUnroll::runOnOperation() {
});
auto *ctx = &getContext();
-
RewritePatternSet patterns(ctx);
patterns.add<BankAffineLoadPattern>(ctx, unrollFactor, memoryToBanks);
@@ -350,33 +364,14 @@ void ParallelUnroll::runOnOperation() {
signalPassFailure();
}
- DenseSet<Block *> blocksToModify;
- for (auto &[memrefVal, banks] : memoryToBanks) {
- if (memrefVal.use_empty()) {
- if (auto blockArg = dyn_cast<BlockArgument>(memrefVal)) {
- blockArg.getOwner()->eraseArgument(blockArg.getArgNumber());
- blocksToModify.insert(blockArg.getOwner());
- } else {
- memrefVal.getDefiningOp()->erase();
- }
- }
- }
+ // Clean up the old memref values
+ DenseSet<Value> oldMemRefVals;
+ for (const auto &pair : memoryToBanks)
+ oldMemRefVals.insert(pair.first);
- for (auto *block : blocksToModify) {
- if (!isa<func::FuncOp>(block->getParentOp()))
- continue;
- func::FuncOp funcOp = cast<func::FuncOp>(block->getParentOp());
- SmallVector<Type, 4> newArgTypes;
- for (BlockArgument arg : funcOp.getArguments()) {
- newArgTypes.push_back(arg.getType());
- }
- FunctionType newFuncType =
- FunctionType::get(funcOp.getContext(), newArgTypes,
- funcOp.getFunctionType().getResults());
- funcOp.setType(newFuncType);
+ if (failed(cleanUpOldMemRefs(oldMemRefVals))) {
+ signalPassFailure();
}
-
- getOperation().dump();
}
std::unique_ptr<OperationPass<func::FuncOp>>
From 9483cb91fe3f845081b17884bd795227a0344d5b Mon Sep 17 00:00:00 2001
From: Jiahan Xie <jx353 at cornell.edu>
Date: Mon, 11 Nov 2024 13:55:02 -0500
Subject: [PATCH 5/6] rename to parallel banking
---
mlir/include/mlir/Dialect/Affine/Passes.h | 7 +-
mlir/include/mlir/Dialect/Affine/Passes.td | 10 +-
.../Dialect/Affine/Transforms/CMakeLists.txt | 2 +-
...ParallelUnroll.cpp => ParallelBanking.cpp} | 109 +++++++++---------
4 files changed, 64 insertions(+), 64 deletions(-)
rename mlir/lib/Dialect/Affine/Transforms/{ParallelUnroll.cpp => ParallelBanking.cpp} (78%)
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h
index 53a941acff1d9e..6bc488f2d3e1e2 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.h
+++ b/mlir/include/mlir/Dialect/Affine/Passes.h
@@ -109,9 +109,12 @@ std::unique_ptr<OperationPass<func::FuncOp>> createLoopUnrollPass(
std::unique_ptr<OperationPass<func::FuncOp>>
createLoopUnrollAndJamPass(int unrollJamFactor = -1);
-std::unique_ptr<OperationPass<func::FuncOp>> createParallelUnrollPass(
+/// Creates a memory banking pass to explicitly partition the memories used
+/// inside affine parallel operations
+std::unique_ptr<OperationPass<func::FuncOp>> createParallelBankingPass(
int unrollFactor = -1,
- const std::function<unsigned(AffineParallelOp)> &getUnrollFactor = nullptr);
+ const std::function<unsigned(AffineParallelOp)> &getBankingFactor =
+ nullptr);
/// Creates a pass to pipeline explicit movement of data across levels of the
/// memory hierarchy.
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
index 55a3583b90a55b..6321750f4b6322 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -381,12 +381,12 @@ def AffineParallelize : Pass<"affine-parallelize", "func::FuncOp"> {
];
}
-def AffineParallelUnroll : Pass<"affine-parallel-unroll", "func::FuncOp"> {
- let summary = "Unroll affine parallel loops";
- let constructor = "mlir::affine::createParallelUnrollPass()";
+def AffineParallelBanking : Pass<"affine-parallel-banking", "func::FuncOp"> {
+ let summary = "Partition the memories used in affine parallel loops into banks";
+ let constructor = "mlir::affine::createParallelBankingPass()";
let options = [
- Option<"unrollFactor", "unroll-factor", "unsigned", /*default=*/"1",
- "Use this unroll factor for all loops being unrolled">
+ Option<"bankingFactor", "banking-factor", "unsigned", /*default=*/"1",
+ "Use this banking factor for all memories being partitioned">
];
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
index c10d9e6a04f746..9c1290636ba77f 100644
--- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
@@ -11,7 +11,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
LoopTiling.cpp
LoopUnroll.cpp
LoopUnrollAndJam.cpp
- ParallelUnroll.cpp
+ ParallelBanking.cpp
PipelineDataTransfer.cpp
ReifyValueBounds.cpp
SuperVectorize.cpp
diff --git a/mlir/lib/Dialect/Affine/Transforms/ParallelUnroll.cpp b/mlir/lib/Dialect/Affine/Transforms/ParallelBanking.cpp
similarity index 78%
rename from mlir/lib/Dialect/Affine/Transforms/ParallelUnroll.cpp
rename to mlir/lib/Dialect/Affine/Transforms/ParallelBanking.cpp
index 41cf46a1603020..fb49db90353dd5 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ParallelUnroll.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ParallelBanking.cpp
@@ -1,4 +1,4 @@
-//===- ParallelUnroll.cpp - Code to perform parallel loop unrolling
+//===- ParallelBanking.cpp - Code to perform memory banking in parallel loops
//--------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
@@ -7,7 +7,7 @@
//
//===----------------------------------------------------------------------===//
//
-// This file implements parallel loop unrolling.
+// This file implements parallel loop memory banking.
//
//===----------------------------------------------------------------------===//
@@ -22,46 +22,41 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
-#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/Debug.h"
-#include "llvm/Support/raw_ostream.h"
#include <cassert>
namespace mlir {
namespace affine {
-#define GEN_PASS_DEF_AFFINEPARALLELUNROLL
+#define GEN_PASS_DEF_AFFINEPARALLELBANKING
#include "mlir/Dialect/Affine/Passes.h.inc"
} // namespace affine
} // namespace mlir
-#define DEBUG_TYPE "affine-parallel-unroll"
+#define DEBUG_TYPE "affine-parallel-banking"
using namespace mlir;
using namespace mlir::affine;
namespace {
-/// Unroll an `affine.parallel` operation by the `unrollFactor` specified in the
-/// attribute. Evenly splitting `memref`s that are present in the `parallel`
-/// region into smaller banks.
-struct ParallelUnroll
- : public affine::impl::AffineParallelUnrollBase<ParallelUnroll> {
- const std::function<unsigned(AffineParallelOp)> getUnrollFactor;
- ParallelUnroll() : getUnrollFactor(nullptr) {}
- ParallelUnroll(const ParallelUnroll &other) = default;
- explicit ParallelUnroll(std::optional<unsigned> unrollFactor = std::nullopt,
- const std::function<unsigned(AffineParallelOp)>
- &getUnrollFactor = nullptr)
- : getUnrollFactor(getUnrollFactor) {
- if (unrollFactor)
- this->unrollFactor = *unrollFactor;
+/// Partition memories used in `affine.parallel` operation by the
+/// `bankingFactor` throughout the program.
+struct ParallelBanking
+ : public affine::impl::AffineParallelBankingBase<ParallelBanking> {
+ const std::function<unsigned(AffineParallelOp)> getBankingFactor;
+ ParallelBanking() : getBankingFactor(nullptr) {}
+ ParallelBanking(const ParallelBanking &other) = default;
+ explicit ParallelBanking(std::optional<unsigned> bankingFactor = std::nullopt,
+ const std::function<unsigned(AffineParallelOp)>
+ &getBankingFactor = nullptr)
+ : getBankingFactor(getBankingFactor) {
+ if (bankingFactor)
+ this->bankingFactor = *bankingFactor;
}
  void getDependentDialects(DialectRegistry &registry) const override {
@@ -69,8 +64,8 @@ struct ParallelUnroll
}
void runOnOperation() override;
- LogicalResult parallelUnrollByFactor(AffineParallelOp parOp,
- uint64_t unrollFactor);
+ LogicalResult parallelBankingByFactor(AffineParallelOp parOp,
+ uint64_t bankingFactor);
private:
// map from original memory definition to newly allocated banks
@@ -108,22 +103,23 @@ MemRefType computeBankedMemRefType(MemRefType originalType,
return newMemRefType;
}
-SmallVector<Value> createBanks(Value originalMem, uint64_t unrollFactor) {
+SmallVector<Value> createBanks(Value originalMem, uint64_t bankingFactor) {
MemRefType originalMemRefType = cast<MemRefType>(originalMem.getType());
MemRefType newMemRefType =
- computeBankedMemRefType(originalMemRefType, unrollFactor);
+ computeBankedMemRefType(originalMemRefType, bankingFactor);
SmallVector<Value, 4> banks;
if (auto blockArgMem = dyn_cast<BlockArgument>(originalMem)) {
Block *block = blockArgMem.getOwner();
unsigned blockArgNum = blockArgMem.getArgNumber();
SmallVector<Type> banksType;
- for (unsigned i = 0; i < unrollFactor; ++i) {
+ for (unsigned i = 0; i < bankingFactor; ++i) {
block->insertArgument(blockArgNum + 1 + i, newMemRefType,
blockArgMem.getLoc());
}
- auto blockArgs = block->getArguments().slice(blockArgNum + 1, unrollFactor);
+ auto blockArgs =
+ block->getArguments().slice(blockArgNum + 1, bankingFactor);
banks.append(blockArgs.begin(), blockArgs.end());
} else {
Operation *originalDef = originalMem.getDefiningOp();
@@ -132,14 +128,14 @@ SmallVector<Value> createBanks(Value originalMem, uint64_t unrollFactor) {
builder.setInsertionPointAfter(originalDef);
TypeSwitch<Operation *>(originalDef)
.Case<memref::AllocOp>([&](memref::AllocOp allocOp) {
- for (uint bankCnt = 0; bankCnt < unrollFactor; bankCnt++) {
+ for (uint bankCnt = 0; bankCnt < bankingFactor; bankCnt++) {
auto bankAllocOp =
builder.create<memref::AllocOp>(loc, newMemRefType);
banks.push_back(bankAllocOp);
}
})
.Case<memref::AllocaOp>([&](memref::AllocaOp allocaOp) {
- for (uint bankCnt = 0; bankCnt < unrollFactor; bankCnt++) {
+ for (uint bankCnt = 0; bankCnt < bankingFactor; bankCnt++) {
auto bankAllocaOp =
builder.create<memref::AllocaOp>(loc, newMemRefType);
banks.push_back(bankAllocaOp);
@@ -153,9 +149,9 @@ SmallVector<Value> createBanks(Value originalMem, uint64_t unrollFactor) {
}
struct BankAffineLoadPattern : public OpRewritePattern<AffineLoadOp> {
- BankAffineLoadPattern(MLIRContext *context, uint64_t unrollFactor,
+ BankAffineLoadPattern(MLIRContext *context, uint64_t bankingFactor,
DenseMap<Value, SmallVector<Value>> &memoryToBanks)
- : OpRewritePattern<AffineLoadOp>(context), unrollFactor(unrollFactor),
+ : OpRewritePattern<AffineLoadOp>(context), bankingFactor(bankingFactor),
memoryToBanks(memoryToBanks) {}
LogicalResult matchAndRewrite(AffineLoadOp loadOp,
@@ -164,9 +160,9 @@ struct BankAffineLoadPattern : public OpRewritePattern<AffineLoadOp> {
auto banks = memoryToBanks[loadOp.getMemref()];
Value loadIndex = loadOp.getIndices().front();
auto modMap =
- AffineMap::get(1, 0, {rewriter.getAffineDimExpr(0) % unrollFactor});
+ AffineMap::get(1, 0, {rewriter.getAffineDimExpr(0) % bankingFactor});
auto divMap = AffineMap::get(
- 1, 0, {rewriter.getAffineDimExpr(0).floorDiv(unrollFactor)});
+ 1, 0, {rewriter.getAffineDimExpr(0).floorDiv(bankingFactor)});
Value bankIndex = rewriter.create<AffineApplyOp>(
loc, modMap, loadIndex); // assuming one-dim
@@ -175,15 +171,15 @@ struct BankAffineLoadPattern : public OpRewritePattern<AffineLoadOp> {
SmallVector<Type> resultTypes = {loadOp.getResult().getType()};
SmallVector<int64_t, 4> caseValues;
- for (unsigned i = 0; i < unrollFactor; ++i)
+ for (unsigned i = 0; i < bankingFactor; ++i)
caseValues.push_back(i);
rewriter.setInsertionPoint(loadOp);
scf::IndexSwitchOp switchOp = rewriter.create<scf::IndexSwitchOp>(
loc, resultTypes, bankIndex, caseValues,
- /*numRegions=*/unrollFactor);
+ /*numRegions=*/bankingFactor);
- for (unsigned i = 0; i < unrollFactor; ++i) {
+ for (unsigned i = 0; i < bankingFactor; ++i) {
Region &caseRegion = switchOp.getCaseRegions()[i];
rewriter.setInsertionPointToStart(&caseRegion.emplaceBlock());
Value bankedLoad = rewriter.create<AffineLoadOp>(loc, banks[i], offset);
@@ -206,14 +202,14 @@ struct BankAffineLoadPattern : public OpRewritePattern<AffineLoadOp> {
}
private:
- uint64_t unrollFactor;
+ uint64_t bankingFactor;
DenseMap<Value, SmallVector<Value>> &memoryToBanks;
};
struct BankAffineStorePattern : public OpRewritePattern<AffineStoreOp> {
- BankAffineStorePattern(MLIRContext *context, uint64_t unrollFactor,
+ BankAffineStorePattern(MLIRContext *context, uint64_t bankingFactor,
DenseMap<Value, SmallVector<Value>> &memoryToBanks)
- : OpRewritePattern<AffineStoreOp>(context), unrollFactor(unrollFactor),
+ : OpRewritePattern<AffineStoreOp>(context), bankingFactor(bankingFactor),
memoryToBanks(memoryToBanks) {}
LogicalResult matchAndRewrite(AffineStoreOp storeOp,
@@ -223,9 +219,9 @@ struct BankAffineStorePattern : public OpRewritePattern<AffineStoreOp> {
Value storeIndex = storeOp.getIndices().front();
auto modMap =
- AffineMap::get(1, 0, {rewriter.getAffineDimExpr(0) % unrollFactor});
+ AffineMap::get(1, 0, {rewriter.getAffineDimExpr(0) % bankingFactor});
auto divMap = AffineMap::get(
- 1, 0, {rewriter.getAffineDimExpr(0).floorDiv(unrollFactor)});
+ 1, 0, {rewriter.getAffineDimExpr(0).floorDiv(bankingFactor)});
Value bankIndex = rewriter.create<AffineApplyOp>(
loc, modMap, storeIndex); // assuming one-dim
@@ -234,15 +230,15 @@ struct BankAffineStorePattern : public OpRewritePattern<AffineStoreOp> {
SmallVector<Type> resultTypes = {};
SmallVector<int64_t, 4> caseValues;
- for (unsigned i = 0; i < unrollFactor; ++i)
+ for (unsigned i = 0; i < bankingFactor; ++i)
caseValues.push_back(i);
rewriter.setInsertionPoint(storeOp);
scf::IndexSwitchOp switchOp = rewriter.create<scf::IndexSwitchOp>(
loc, resultTypes, bankIndex, caseValues,
- /*numRegions=*/unrollFactor);
+ /*numRegions=*/bankingFactor);
- for (unsigned i = 0; i < unrollFactor; ++i) {
+ for (unsigned i = 0; i < bankingFactor; ++i) {
Region &caseRegion = switchOp.getCaseRegions()[i];
rewriter.setInsertionPointToStart(&caseRegion.emplaceBlock());
rewriter.create<AffineStoreOp>(loc, storeOp.getValueToStore(), banks[i],
@@ -261,7 +257,7 @@ struct BankAffineStorePattern : public OpRewritePattern<AffineStoreOp> {
}
private:
- uint64_t unrollFactor;
+ uint64_t bankingFactor;
DenseMap<Value, SmallVector<Value>> &memoryToBanks;
};
@@ -335,7 +331,7 @@ LogicalResult cleanUpOldMemRefs(DenseSet<Value> &oldMemRefVals) {
return success();
}
-void ParallelUnroll::runOnOperation() {
+void ParallelBanking::runOnOperation() {
if (getOperation().isExternal()) {
return;
}
@@ -344,7 +340,7 @@ void ParallelUnroll::runOnOperation() {
DenseSet<Value> memrefsInPar = collectMemRefs(parOp);
for (auto memrefVal : memrefsInPar) {
- SmallVector<Value> banks = createBanks(memrefVal, unrollFactor);
+ SmallVector<Value> banks = createBanks(memrefVal, bankingFactor);
memoryToBanks[memrefVal] = banks;
}
});
@@ -352,8 +348,8 @@ void ParallelUnroll::runOnOperation() {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
- patterns.add<BankAffineLoadPattern>(ctx, unrollFactor, memoryToBanks);
- patterns.add<BankAffineStorePattern>(ctx, unrollFactor, memoryToBanks);
+ patterns.add<BankAffineLoadPattern>(ctx, bankingFactor, memoryToBanks);
+ patterns.add<BankAffineStorePattern>(ctx, bankingFactor, memoryToBanks);
patterns.add<BankReturnPattern>(ctx, memoryToBanks);
GreedyRewriteConfig config;
@@ -375,10 +371,11 @@ void ParallelUnroll::runOnOperation() {
}
std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::affine::createParallelUnrollPass(
- int unrollFactor,
- const std::function<unsigned(AffineParallelOp)> &getUnrollFactor) {
- return std::make_unique<ParallelUnroll>(
- unrollFactor == -1 ? std::nullopt : std::optional<unsigned>(unrollFactor),
- getUnrollFactor);
+mlir::affine::createParallelBankingPass(
+ int bankingFactor,
+ const std::function<unsigned(AffineParallelOp)> &getBankingFactor) {
+ return std::make_unique<ParallelBanking>(
+ bankingFactor == -1 ? std::nullopt
+ : std::optional<unsigned>(bankingFactor),
+ getBankingFactor);
}
From 2086f630b100d19c080c3fc86bb47e477798514f Mon Sep 17 00:00:00 2001
From: Jiahan Xie <jx353 at cornell.edu>
Date: Mon, 11 Nov 2024 14:12:34 -0500
Subject: [PATCH 6/6] add test case for a one dimensional memory
---
.../test/Dialect/Affine/parallel-banking.mlir | 69 +++++++++++++++++++
1 file changed, 69 insertions(+)
create mode 100644 mlir/test/Dialect/Affine/parallel-banking.mlir
diff --git a/mlir/test/Dialect/Affine/parallel-banking.mlir b/mlir/test/Dialect/Affine/parallel-banking.mlir
new file mode 100644
index 00000000000000..6300871a440269
--- /dev/null
+++ b/mlir/test/Dialect/Affine/parallel-banking.mlir
@@ -0,0 +1,69 @@
+// RUN: mlir-opt %s -split-input-file -affine-parallel-banking="banking-factor=2" | FileCheck %s
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0) -> (d0 mod 2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0) -> (d0 floordiv 2)>
+
+// CHECK-LABEL: func.func @parallel_bank_one_dim(
+// CHECK: %[[VAL_0:arg0]]: memref<4xf32>,
+// CHECK: %[[VAL_1:arg1]]: memref<4xf32>,
+// CHECK: %[[VAL_2:arg2]]: memref<4xf32>,
+// CHECK: %[[VAL_3:arg3]]: memref<4xf32>) -> (memref<4xf32>, memref<4xf32>) {
+// CHECK: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_5:.*]] = memref.alloc() : memref<4xf32>
+// CHECK: %[[VAL_6:.*]] = memref.alloc() : memref<4xf32>
+// CHECK: affine.parallel (%[[VAL_7:.*]]) = (0) to (8) {
+// CHECK: %[[VAL_8:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_7]])
+// CHECK: %[[VAL_9:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_7]])
+// CHECK: %[[VAL_10:.*]] = scf.index_switch %[[VAL_8]] -> f32
+// CHECK: case 0 {
+// CHECK: %[[VAL_11:.*]] = affine.load %[[VAL_0]]{{\[}}%[[VAL_9]]] : memref<4xf32>
+// CHECK: scf.yield %[[VAL_11]] : f32
+// CHECK: }
+// CHECK: case 1 {
+// CHECK: %[[VAL_12:.*]] = affine.load %[[VAL_1]]{{\[}}%[[VAL_9]]] : memref<4xf32>
+// CHECK: scf.yield %[[VAL_12]] : f32
+// CHECK: }
+// CHECK: default {
+// CHECK: scf.yield %[[VAL_4]] : f32
+// CHECK: }
+// CHECK: %[[VAL_13:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_7]])
+// CHECK: %[[VAL_14:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_7]])
+// CHECK: %[[VAL_15:.*]] = scf.index_switch %[[VAL_13]] -> f32
+// CHECK: case 0 {
+// CHECK: %[[VAL_16:.*]] = affine.load %[[VAL_2]]{{\[}}%[[VAL_14]]] : memref<4xf32>
+// CHECK: scf.yield %[[VAL_16]] : f32
+// CHECK: }
+// CHECK: case 1 {
+// CHECK: %[[VAL_17:.*]] = affine.load %[[VAL_3]]{{\[}}%[[VAL_14]]] : memref<4xf32>
+// CHECK: scf.yield %[[VAL_17]] : f32
+// CHECK: }
+// CHECK: default {
+// CHECK: scf.yield %[[VAL_4]] : f32
+// CHECK: }
+// CHECK: %[[VAL_18:.*]] = arith.mulf %[[VAL_10]], %[[VAL_15]] : f32
+// CHECK: %[[VAL_19:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_7]])
+// CHECK: %[[VAL_20:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_7]])
+// CHECK: scf.index_switch %[[VAL_19]]
+// CHECK: case 0 {
+// CHECK: affine.store %[[VAL_18]], %[[VAL_5]]{{\[}}%[[VAL_20]]] : memref<4xf32>
+// CHECK: scf.yield
+// CHECK: }
+// CHECK: case 1 {
+// CHECK: affine.store %[[VAL_18]], %[[VAL_6]]{{\[}}%[[VAL_20]]] : memref<4xf32>
+// CHECK: scf.yield
+// CHECK: }
+// CHECK: default {
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[VAL_5]], %[[VAL_6]] : memref<4xf32>, memref<4xf32>
+// CHECK: }
+func.func @parallel_bank_one_dim(%arg0: memref<8xf32>, %arg1: memref<8xf32>) -> (memref<8xf32>) {
+ %mem = memref.alloc() : memref<8xf32>
+ affine.parallel (%i) = (0) to (8) {
+ %1 = affine.load %arg0[%i] : memref<8xf32>
+ %2 = affine.load %arg1[%i] : memref<8xf32>
+ %3 = arith.mulf %1, %2 : f32
+ affine.store %3, %mem[%i] : memref<8xf32>
+ }
+ return %mem : memref<8xf32>
+}
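To make the interface change explicit (a sketch inferred from the CHECK lines above, not an additional test): with banking-factor=2 every logical memref in the function signature is split into two banks, so the function type is rewritten roughly as follows.

    // Before
    func.func @parallel_bank_one_dim(%arg0: memref<8xf32>, %arg1: memref<8xf32>)
        -> memref<8xf32>

    // After (sketch): each 8-element argument becomes two 4-element banks, and
    // the locally allocated result is returned as its two banks.
    func.func @parallel_bank_one_dim(%arg0: memref<4xf32>, %arg1: memref<4xf32>,
                                     %arg2: memref<4xf32>, %arg3: memref<4xf32>)
        -> (memref<4xf32>, memref<4xf32>)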