[llvm-branch-commits] [mlir] 4f5c304 - [MLIR] Add option to enforce perfect nests during loop fusion

Vinayaka Bandishti via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Nov 5 03:30:32 PDT 2021


Author: Prateek Gupta
Date: 2021-09-27T22:46:32+05:30
New Revision: 4f5c3045514341cd5e98bc7b84dd05d49ef37cb7

URL: https://github.com/llvm/llvm-project/commit/4f5c3045514341cd5e98bc7b84dd05d49ef37cb7
DIFF: https://github.com/llvm/llvm-project/commit/4f5c3045514341cd5e98bc7b84dd05d49ef37cb7.diff

LOG: [MLIR] Add option to enforce perfect nests during loop fusion

This commit adds an option to the affine-cs pipeline to enforce perfect
nests during loop fusion. Relevant test cases are also added.

Signed-Off-By: Prateek Gupta <prateek at polymagelabs.com>

Changes while porting to upstream by Vinayaka Bandishti:
1. Removed the test case that fused a reduction into a pointwise loop nest,
since that fusion was being prevented for the wrong reason.

Added: 
    mlir/test/Transforms/affine-cs-pipeline-perfect-nest.mlir

Modified: 
    mlir/include/mlir/Transforms/Passes.h
    mlir/include/mlir/Transforms/Passes.td
    mlir/lib/Transforms/AffineCSPipeline.cpp
    mlir/lib/Transforms/LoopFusion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 4296fcdc590c..3d402eb4f61c 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -26,6 +26,9 @@ class AffineForOp;
 class GreedyRewriteConfig;
 class OpPassManager;
 
+// Options for createAffineCSPipeline (defined in AffineCSPipeline.cpp).
+struct AffineCSPipelineOptions;
+
 //===----------------------------------------------------------------------===//
 // Passes
 //===----------------------------------------------------------------------===//
@@ -75,11 +78,11 @@ std::unique_ptr<Pass> createCSEPass();
 
 /// Creates a loop fusion pass which fuses loops. Buffers of size less than or
 /// equal to `localBufSizeThreshold` are promoted to memory space
-/// `fastMemorySpace'.
-std::unique_ptr<OperationPass<FuncOp>>
-createLoopFusionPass(unsigned fastMemorySpace = 0,
-                     uint64_t localBufSizeThreshold = 0,
-                     bool maximalFusion = false);
+/// `fastMemorySpace`. If `enforcePerfectNest` is true, fusion only happens at
+/// the destination nest's access depth so that loop nests stay perfect.
+std::unique_ptr<OperationPass<FuncOp>> createLoopFusionPass(
+    unsigned fastMemorySpace = 0, uint64_t localBufSizeThreshold = 0,
+    bool maximalFusion = false, bool enforcePerfectNest = false);
 
 /// Creates a loop invariant code motion pass that hoists loop invariant
 /// instructions out of the loop.
@@ -143,7 +146,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createNormalizeMemRefsPass();
 
 /// Creates an affine optimization pipeline including fusion and other
 /// complementary passes.
-void createAffineCSPipeline(OpPassManager &pm);
+void createAffineCSPipeline(OpPassManager &pm,
+                            const AffineCSPipelineOptions &options);
 
 } // end namespace mlir
 

diff  --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 91af2a2c56a9..479b811cbfc0 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -136,6 +136,9 @@ def AffineLoopFusion : FunctionPass<"affine-loop-fusion"> {
                              "to fast memory space">,
     Option<"maximalFusion", "fusion-maximal", "bool", /*default=*/"false",
            "Enables maximal loop fusion">,
+    Option<"enforcePerfectNest", "enforce-perfect-nest", "bool",
+           /*default=*/"false",
+           "Enforces perfect nesting while performing loop fusion">
   ];
   let dependentDialects = ["memref::MemRefDialect"];
 }

diff  --git a/mlir/lib/Transforms/AffineCSPipeline.cpp b/mlir/lib/Transforms/AffineCSPipeline.cpp
index c6994188b0fd..340e4ad50c5b 100644
--- a/mlir/lib/Transforms/AffineCSPipeline.cpp
+++ b/mlir/lib/Transforms/AffineCSPipeline.cpp
@@ -17,19 +17,35 @@
 #include "mlir/Transforms/Passes.h"
 
 namespace mlir {
-void createAffineCSPipeline(OpPassManager &pm) {
+/// Command-line options accepted by the `affine-cs-pipeline` pipeline.
+struct AffineCSPipelineOptions
+    : public PassPipelineOptions<AffineCSPipelineOptions> {
+  Option<bool> enforcePerfectNest{
+      *this, "enforce-perfect-nest",
+      llvm::cl::desc("Enables enforcement of perfect nesting while performing "
+                     "loop fusion."),
+      llvm::cl::init(false)};
+};
+
+/// Populates `pm` with: canonicalize, maximal affine loop fusion (perfect-nest
+/// enforcement taken from `options`), canonicalize, scalar replacement, canonicalize.
+void createAffineCSPipeline(OpPassManager &pm,
+                            const AffineCSPipelineOptions &options) {
   pm.addPass(mlir::createCanonicalizerPass());
-  pm.addPass(mlir::createLoopFusionPass(/*fastMemorySpace=*/0,
-                                        /*localBufSizeThreshold=*/0,
-                                        /*maximalFusion=*/true));
+  pm.addPass(mlir::createLoopFusionPass(
+      /*fastMemorySpace=*/0,
+      /*localBufSizeThreshold=*/0,
+      /*maximalFusion=*/true,
+      /*enforcePerfectNest=*/options.enforcePerfectNest));
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(mlir::createAffineScalarReplacementPass());
   pm.addPass(mlir::createCanonicalizerPass());
 }
+
 void registerAffineCSPipeline() {
-  mlir::PassPipelineRegistration<>(
+  mlir::PassPipelineRegistration<AffineCSPipelineOptions>(
       "affine-cs-pipeline",
-      "runs all passes for performing affine loop fusion and other "
-      "complimentary optimizations.",
+      "runs passes to perform affine loop fusion and other "
+      "complementary optimizations.",
       createAffineCSPipeline);
 }

diff  --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index c19c887a593d..5bedaeafc99f 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -49,10 +49,11 @@ namespace {
 struct LoopFusion : public AffineLoopFusionBase<LoopFusion> {
   LoopFusion() = default;
   LoopFusion(unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes,
-             bool maximalFusion) {
+             bool maximalFusion, bool enforcePerfectNest) {
     this->fastMemorySpace = fastMemorySpace;
     this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024;
     this->maximalFusion = maximalFusion;
+    this->enforcePerfectNest = enforcePerfectNest;
   }
 
   void runOnFunction() override;
@@ -62,9 +63,10 @@ struct LoopFusion : public AffineLoopFusionBase<LoopFusion> {
 
 std::unique_ptr<OperationPass<FuncOp>>
 mlir::createLoopFusionPass(unsigned fastMemorySpace,
-                           uint64_t localBufSizeThreshold, bool maximalFusion) {
+                           uint64_t localBufSizeThreshold, bool maximalFusion,
+                           bool enforcePerfectNest) {
   return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
-                                      maximalFusion);
+                                      maximalFusion, enforcePerfectNest);
 }
 
 namespace {
@@ -1397,16 +1399,18 @@ struct GreedyFusion {
   //    unique consumer.
   // *) Second pass fuses sibling nodes which share no dependence edges.
   // *) Third pass fuses any remaining producer nodes into their users.
-  void run() {
+  void run(bool enforcePerfectNest = false) {
     // TODO: Run this repeatedly until a fixed-point is reached.
-    fuseProducerConsumerNodes(/*maxSrcUserCount=*/1);
+    fuseProducerConsumerNodes(/*maxSrcUserCount=*/1, enforcePerfectNest);
     fuseSiblingNodes();
     fuseProducerConsumerNodes(
-        /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
+        /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max(),
+        enforcePerfectNest);
     eraseUnusedMemRefAllocations();
   }
 
-  void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
+  void fuseProducerConsumerNodes(unsigned maxSrcUserCount,
+                                 bool enforcePerfectNest) {
     LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
     init();
     while (!worklist.empty()) {
@@ -1606,10 +1610,18 @@ struct GreedyFusion {
             privateMemrefs.insert(memref);
           }
 
+          // To keep the fused nest perfect, only fuse at 'dstLoopDepthTest',
+          // i.e. the consumer's access depth; skip this candidate otherwise.
+          if (enforcePerfectNest) {
+            if (bestDstLoopDepth != dstLoopDepthTest) {
+              continue;
+            }
+          }
+
           // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
-          fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
+          fuseLoops(srcAffineForOp, dstAffineForOp,
+                    depthSliceUnions[bestDstLoopDepth - 1]);
           dstNodeChanged = true;
-
           LLVM_DEBUG(llvm::dbgs()
                      << "Fused src loop " << srcId << " into dst loop " << dstId
                      << " at depth " << bestDstLoopDepth << ":\n"
@@ -1971,5 +1983,5 @@ void LoopFusion::runOnFunction() {
   unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
   GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
                       maximalFusion, computeToleranceThreshold);
-  fusion.run();
+  fusion.run(this->enforcePerfectNest);
 }

diff  --git a/mlir/test/Transforms/affine-cs-pipeline-perfect-nest.mlir b/mlir/test/Transforms/affine-cs-pipeline-perfect-nest.mlir
new file mode 100644
index 000000000000..a810d700a784
--- /dev/null
+++ b/mlir/test/Transforms/affine-cs-pipeline-perfect-nest.mlir
@@ -0,0 +1,81 @@
+// RUN: mlir-opt %s -affine-cs-pipeline="enforce-perfect-nest" | FileCheck %s
+
+// This test case checks the enforcement of perfect nesting for simple matrix multiplication.
+func @simple_matmul_one(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf32>, %arg2: memref<16x16xf32>) {
+    %cst = constant 0.000000e+00 : f32
+    affine.for %arg3 = 0 to 16 {
+        affine.for %arg4 = 0 to 16 {
+          affine.store %cst, %arg2[%arg3, %arg4] : memref<16x16xf32>
+        }
+    }
+    affine.for %arg3 = 0 to 16 {
+      affine.for %arg4 = 0 to 16 {
+        affine.for %arg5 = 0 to 16 {
+          %0 = affine.load %arg0[%arg3, %arg5] : memref<16x16xf32>
+          %1 = affine.load %arg1[%arg5, %arg4] : memref<16x16xf32>
+          %2 = affine.load %arg2[%arg3, %arg4] : memref<16x16xf32>
+          %3 = mulf %0, %1 : f32
+          %4 = addf %3, %2 : f32
+          affine.store %4, %arg2[%arg3, %arg4] : memref<16x16xf32>
+        }
+      }
+    }
+    return
+}
+
+// CHECK-LABEL: func @simple_matmul_one
+// CHECK:       (%[[LHS:.*]]: memref<16x16xf32>, %[[RHS:.*]]: memref<16x16xf32>, %[[OUT:.*]]: memref<16x16xf32>) {
+// CHECK:           %[[INIT:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT:      affine.for %[[i:.*]] = 0 to 16 {
+// CHECK-NEXT:          affine.for %[[j:.*]] = 0 to 16 {
+// CHECK-NEXT:              affine.store %[[INIT]], %[[OUT]][%[[i]], %[[j]]]
+// CHECK:           affine.for %[[i:.*]] = 0 to 16 {
+// CHECK-NEXT:          affine.for %[[j:.*]] = 0 to 16 {
+// CHECK-NEXT:              affine.for %[[k:.*]] = 0 to 16 {
+// CHECK-NEXT:                  %[[LHS_VAL:.*]] = affine.load %[[LHS]][%[[i]], %[[k]]]
+// CHECK-NEXT:                  %[[RHS_VAL:.*]] = affine.load %[[RHS]][%[[k]], %[[j]]]
+// CHECK-NEXT:                  %[[OUT_VAL:.*]] = affine.load %[[OUT]][%[[i]], %[[j]]]
+// CHECK-NEXT:                  %[[PROD:.*]] = mulf %[[LHS_VAL]], %[[RHS_VAL]]
+// CHECK-NEXT:                  %[[RES:.*]] = addf %[[PROD]], %[[OUT_VAL]]
+// CHECK-NEXT:                  affine.store %[[RES]], %[[OUT]][%[[i]], %[[j]]]
+
+func @should_fuse_pointwise(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf32>, %arg2: memref<16x16xf32>){
+  affine.for %arg3 = 0 to 16 {
+    affine.for %arg4 = 0 to 16 {
+      %0 = affine.load %arg0[%arg3, %arg4] : memref<16x16xf32>
+      %1 = affine.load %arg1[%arg3, %arg4] : memref<16x16xf32>
+      %2 = addf %0, %1 : f32
+      affine.store %2, %arg0[%arg3, %arg4] : memref<16x16xf32>
+    }
+  }
+  affine.for %arg3 = 0 to 16 {
+    affine.for %arg4 = 0 to 16 {
+      affine.for %arg5 = 0 to 16 {
+        %0 = affine.load %arg0[%arg3, %arg5] : memref<16x16xf32>
+        %1 = affine.load %arg1[%arg5, %arg4] : memref<16x16xf32>
+        %2 = affine.load %arg2[%arg3, %arg4] : memref<16x16xf32>
+        %3 = mulf %0, %1 : f32
+        %4 = addf %3, %2 : f32
+        affine.store %4, %arg2[%arg3, %arg4] : memref<16x16xf32>
+      }
+    }
+  }
+  return
+}
+
+// NOTE(review): the producer nest above updates %arg0 in place. Fused at
+// depth 3 (as the CHECKs below expect), that update re-runs for every
+// iteration of the consumer's middle loop, unlike the unfused code; confirm.
+// CHECK-LABEL: func @should_fuse_pointwise
+// CHECK:         affine.for
+// CHECK-NEXT:      affine.for
+// CHECK-NEXT:        affine.for
+// CHECK-NEXT:          {{.*}} = affine.load
+// CHECK-NEXT:          {{.*}} = affine.load
+// CHECK-NEXT:          {{.*}} = addf
+// CHECK-NEXT:          affine.store
+// CHECK-NEXT:          {{.*}} = affine.load
+// CHECK-NEXT:          {{.*}} = affine.load
+// CHECK-NEXT:          {{.*}} = mulf
+// CHECK-NEXT:          {{.*}} = addf
+// CHECK-NEXT:          affine.store


        


More information about the llvm-branch-commits mailing list