[llvm-branch-commits] [mlir] 4f5c304 - [MLIR] Add to enforce perfect nests while loop fusion
Vinayaka Bandishti via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Nov 5 03:30:32 PDT 2021
Author: Prateek Gupta
Date: 2021-09-27T22:46:32+05:30
New Revision: 4f5c3045514341cd5e98bc7b84dd05d49ef37cb7
URL: https://github.com/llvm/llvm-project/commit/4f5c3045514341cd5e98bc7b84dd05d49ef37cb7
DIFF: https://github.com/llvm/llvm-project/commit/4f5c3045514341cd5e98bc7b84dd05d49ef37cb7.diff
LOG: [MLIR] Add to enforce perfect nests while loop fusion
This commit adds an option to the affine-cs pipeline to enforce perfect
nests during loop fusion. Relevant test cases are also added.
Signed-Off-By: Prateek Gupta <prateek at polymagelabs.com>
Changes while porting to upstream by Vinayaka Bandishti:
1. Removed the test case for reduction fusion into pointwise, since this
fusion was being prevented for an incorrect reason.
Added:
mlir/test/Transforms/affine-cs-pipeline-perfect-nest.mlir
Modified:
mlir/include/mlir/Transforms/Passes.h
mlir/include/mlir/Transforms/Passes.td
mlir/lib/Transforms/AffineCSPipeline.cpp
mlir/lib/Transforms/LoopFusion.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 4296fcdc590c..3d402eb4f61c 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -26,6 +26,9 @@ class AffineForOp;
class GreedyRewriteConfig;
class OpPassManager;
+// Cerebras specific options.
+struct AffineCSPipelineOptions;
+
//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//
@@ -75,11 +78,11 @@ std::unique_ptr<Pass> createCSEPass();
/// Creates a loop fusion pass which fuses loops. Buffers of size less than or
/// equal to `localBufSizeThreshold` are promoted to memory space
-/// `fastMemorySpace'.
-std::unique_ptr<OperationPass<FuncOp>>
-createLoopFusionPass(unsigned fastMemorySpace = 0,
- uint64_t localBufSizeThreshold = 0,
- bool maximalFusion = false);
+/// `fastMemorySpace`. If `enforcePerfectNest` is true, perfect nesting is
+/// enforced while fusing.
+std::unique_ptr<OperationPass<FuncOp>> createLoopFusionPass(
+ unsigned fastMemorySpace = 0, uint64_t localBufSizeThreshold = 0,
+ bool maximalFusion = false, bool enforcePerfectNest = false);
/// Creates a loop invariant code motion pass that hoists loop invariant
/// instructions out of the loop.
@@ -143,7 +146,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createNormalizeMemRefsPass();
/// Creates an affine optimization pipeline including fusion and other
/// complementary passes.
-void createAffineCSPipeline(OpPassManager &pm);
+void createAffineCSPipeline(OpPassManager &pm,
+ const AffineCSPipelineOptions &options);
} // end namespace mlir
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 91af2a2c56a9..479b811cbfc0 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -136,6 +136,8 @@ def AffineLoopFusion : FunctionPass<"affine-loop-fusion"> {
"to fast memory space">,
Option<"maximalFusion", "fusion-maximal", "bool", /*default=*/"false",
"Enables maximal loop fusion">,
+ Option<"enforcePerfectNest", "enforce-perfect-nest", "bool", /*default=*/"false",
+ "Enforces perfect nesting while performing loop fusion.">
];
let dependentDialects = ["memref::MemRefDialect"];
}
diff --git a/mlir/lib/Transforms/AffineCSPipeline.cpp b/mlir/lib/Transforms/AffineCSPipeline.cpp
index c6994188b0fd..340e4ad50c5b 100644
--- a/mlir/lib/Transforms/AffineCSPipeline.cpp
+++ b/mlir/lib/Transforms/AffineCSPipeline.cpp
@@ -17,19 +17,32 @@
#include "mlir/Transforms/Passes.h"
namespace mlir {
-void createAffineCSPipeline(OpPassManager &pm) {
+struct AffineCSPipelineOptions
+ : public PassPipelineOptions<AffineCSPipelineOptions> {
+ Option<bool> enforce_perfect_nest{
+ *this, "enforce-perfect-nest",
+ llvm::cl::desc("Enables enforcement of perfect nesting while performing "
+ "loop fusion."),
+ llvm::cl::init(false)};
+};
+
+void createAffineCSPipeline(OpPassManager &pm,
+ const AffineCSPipelineOptions &options) {
pm.addPass(mlir::createCanonicalizerPass());
- pm.addPass(mlir::createLoopFusionPass(/*fastMemorySpace=*/0,
- /*localBufSizeThreshold=*/0,
- /*maximalFusion=*/true));
+ pm.addPass(mlir::createLoopFusionPass(
+ /*fastMemorySpace=*/0,
+ /*localBufSizeThreshold=*/0,
+ /*maximalFusion=*/true,
+ /*enforcePerfectNest=*/options.enforce_perfect_nest));
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createAffineScalarReplacementPass());
pm.addPass(mlir::createCanonicalizerPass());
}
+
void registerAffineCSPipeline() {
- mlir::PassPipelineRegistration<>(
+ mlir::PassPipelineRegistration<AffineCSPipelineOptions>(
"affine-cs-pipeline",
- "runs all passes for performing affine loop fusion and other "
+ "runs passes to perform affine loop fusion and other "
"complimentary optimizations.",
createAffineCSPipeline);
}
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index c19c887a593d..5bedaeafc99f 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -49,10 +49,11 @@ namespace {
struct LoopFusion : public AffineLoopFusionBase<LoopFusion> {
LoopFusion() = default;
LoopFusion(unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes,
- bool maximalFusion) {
+ bool maximalFusion, bool enforcePerfectNest) {
this->fastMemorySpace = fastMemorySpace;
this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024;
this->maximalFusion = maximalFusion;
+ this->enforcePerfectNest = enforcePerfectNest;
}
void runOnFunction() override;
@@ -62,9 +63,10 @@ struct LoopFusion : public AffineLoopFusionBase<LoopFusion> {
std::unique_ptr<OperationPass<FuncOp>>
mlir::createLoopFusionPass(unsigned fastMemorySpace,
- uint64_t localBufSizeThreshold, bool maximalFusion) {
+ uint64_t localBufSizeThreshold, bool maximalFusion,
+ bool enforcePerfectNest) {
return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
- maximalFusion);
+ maximalFusion, enforcePerfectNest);
}
namespace {
@@ -1397,16 +1399,18 @@ struct GreedyFusion {
// unique consumer.
// *) Second pass fuses sibling nodes which share no dependence edges.
// *) Third pass fuses any remaining producer nodes into their users.
- void run() {
+ void run(bool enforcePerfectNest = false) {
// TODO: Run this repeatedly until a fixed-point is reached.
- fuseProducerConsumerNodes(/*maxSrcUserCount=*/1);
+ fuseProducerConsumerNodes(/*maxSrcUserCount=*/1, enforcePerfectNest);
fuseSiblingNodes();
fuseProducerConsumerNodes(
- /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
+ /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max(),
+ enforcePerfectNest);
eraseUnusedMemRefAllocations();
}
- void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
+ void fuseProducerConsumerNodes(unsigned maxSrcUserCount,
+ bool enforcePerfectNest) {
LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
init();
while (!worklist.empty()) {
@@ -1606,10 +1610,18 @@ struct GreedyFusion {
privateMemrefs.insert(memref);
}
+ // To preserve perfect nesting, the required destination loop depth
+ // must be the depth of the load operation.
+ if (enforcePerfectNest) {
+ if (bestDstLoopDepth != dstLoopDepthTest) {
+ continue;
+ }
+ }
+
// Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
- fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
+ fuseLoops(srcAffineForOp, dstAffineForOp,
+ depthSliceUnions[bestDstLoopDepth - 1]);
dstNodeChanged = true;
-
LLVM_DEBUG(llvm::dbgs()
<< "Fused src loop " << srcId << " into dst loop " << dstId
<< " at depth " << bestDstLoopDepth << ":\n"
@@ -1971,5 +1983,5 @@ void LoopFusion::runOnFunction() {
unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
maximalFusion, computeToleranceThreshold);
- fusion.run();
+ fusion.run(this->enforcePerfectNest);
}
diff --git a/mlir/test/Transforms/affine-cs-pipeline-perfect-nest.mlir b/mlir/test/Transforms/affine-cs-pipeline-perfect-nest.mlir
new file mode 100644
index 000000000000..a810d700a784
--- /dev/null
+++ b/mlir/test/Transforms/affine-cs-pipeline-perfect-nest.mlir
@@ -0,0 +1,78 @@
+// RUN: mlir-opt %s -affine-cs-pipeline="enforce-perfect-nest" | FileCheck %s
+
+// This test case checks the enforcement of perfect nesting for simple matrix multiplication.
+func @simple_matmul_one(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf32>, %arg2: memref<16x16xf32>) {
+ %cst = constant 0.000000e+00 : f32
+ affine.for %arg3 = 0 to 16 {
+ affine.for %arg4 = 0 to 16 {
+ affine.store %cst, %arg2[%arg3, %arg4] : memref<16x16xf32>
+ }
+ }
+ affine.for %arg3 = 0 to 16 {
+ affine.for %arg4 = 0 to 16 {
+ affine.for %arg5 = 0 to 16 {
+ %0 = affine.load %arg0[%arg3, %arg5] : memref<16x16xf32>
+ %1 = affine.load %arg1[%arg5, %arg4] : memref<16x16xf32>
+ %2 = affine.load %arg2[%arg3, %arg4] : memref<16x16xf32>
+ %3 = mulf %0, %1 : f32
+ %4 = addf %3, %2 : f32
+ affine.store %4, %arg2[%arg3, %arg4] : memref<16x16xf32>
+ }
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: func @simple_matmul_one
+// CHECK: (%[[LHS:.*]]: memref<16x16xf32>, %[[RHS:.*]]: memref<16x16xf32>, %[[OUT:.*]]: memref<16x16xf32>) {
+// CHECK: %[[INIT:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: affine.for %[[i:.*]] = 0 to 16 {
+// CHECK-NEXT: affine.for %[[j:.*]] = 0 to 16 {
+// CHECK-NEXT: affine.store %[[INIT]], %[[OUT]][%[[i]], %[[j]]]
+// CHECK: affine.for %[[i:.*]] = 0 to 16 {
+// CHECK-NEXT: affine.for %[[j:.*]] = 0 to 16 {
+// CHECK-NEXT: affine.for %[[k:.*]] = 0 to 16 {
+// CHECK-NEXT: %[[LHS_VAL:.*]] = affine.load %[[LHS]][%[[i]], %[[k]]]
+// CHECK-NEXT: %[[RHS_VAL:.*]] = affine.load %[[RHS]][%[[k]], %[[j]]]
+// CHECK-NEXT: %[[OUT_VAL:.*]] = affine.load %[[OUT]][%[[i]], %[[j]]]
+// CHECK-NEXT: %[[PROD:.*]] = mulf %[[LHS_VAL]], %[[RHS_VAL]]
+// CHECK-NEXT: %[[RES:.*]] = addf %[[PROD]], %[[OUT_VAL]]
+// CHECK-NEXT: affine.store %[[RES]], %[[OUT]][%[[i]], %[[j]]]
+
+func @should_fuse_pointwise(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf32>, %arg2: memref<16x16xf32>){
+ affine.for %arg3 = 0 to 16 {
+ affine.for %arg4 = 0 to 16 {
+ %0 = affine.load %arg0[%arg3, %arg4] : memref<16x16xf32>
+ %1 = affine.load %arg1[%arg3, %arg4] : memref<16x16xf32>
+ %2 = addf %0, %1 : f32
+ affine.store %2, %arg0[%arg3, %arg4] : memref<16x16xf32>
+ }
+ }
+ affine.for %arg3 = 0 to 16 {
+ affine.for %arg4 = 0 to 16 {
+ affine.for %arg5 = 0 to 16 {
+ %0 = affine.load %arg0[%arg3, %arg5] : memref<16x16xf32>
+ %1 = affine.load %arg1[%arg5, %arg4] : memref<16x16xf32>
+ %2 = affine.load %arg2[%arg3, %arg4] : memref<16x16xf32>
+ %3 = mulf %0, %1 : f32
+ %4 = addf %3, %2 : f32
+ affine.store %4, %arg2[%arg3, %arg4] : memref<16x16xf32>
+ }
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: func @should_fuse_pointwise
+// CHECK: affine.for
+// CHECK-NEXT: affine.for
+// CHECK-NEXT: affine.for
+// CHECK-NEXT: {{.*}} = affine.load
+// CHECK-NEXT: {{.*}} = affine.load
+// CHECK-NEXT: {{.*}} = addf
+// CHECK-NEXT: affine.store
+// CHECK-NEXT: {{.*}} = affine.load
+// CHECK-NEXT: {{.*}} = affine.load
+// CHECK-NEXT: {{.*}} = mulf
+// CHECK-NEXT: {{.*}} = addf
+// CHECK-NEXT: affine.store
More information about the llvm-branch-commits
mailing list