[Mlir-commits] [mlir] [MLIR][Affine] Fix affine loop fusion with vector ops #115849, #120227 (PR #122799)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 13 13:57:41 PST 2025


https://github.com/brod4910 updated https://github.com/llvm/llvm-project/pull/122799

>From b1aa2c719e1809156815edf961a743fd00f64567 Mon Sep 17 00:00:00 2001
From: brod4910 <brod4910 at gmail.com>
Date: Mon, 13 Jan 2025 14:44:45 -0700
Subject: [PATCH 1/3] Fix affine loop fusion with vector ops #115849, #120227

---
 .../Dialect/Affine/Transforms/LoopFusion.cpp  | 210 ++++++++++++++++--
 mlir/test/Dialect/Affine/loop-fusion-4.mlir   | 145 ++++++++++++
 2 files changed, 339 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 6fefe4487ef59a..5039162ac5bc06 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -10,6 +10,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Analysis/Presburger/IntegerRelation.h"
+#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
 #include "mlir/Dialect/Affine/Passes.h"
 
 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
@@ -23,14 +25,22 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Types.h"
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
+#include <cstdint>
 #include <iomanip>
+#include <iostream>
 #include <optional>
 #include <sstream>
 
@@ -177,13 +187,115 @@ gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
                                 producerConsumerMemrefs);
 }
 
+/// Checks the shapes of the loads and stores of each memref in
+/// the producer/consumer chains. If the load shapes are larger
+/// than the stores then we cannot fuse the loops. The loads
+/// would have a dependency on the values stored.
+static bool checkLoadStoreShapes(unsigned srcId, unsigned dstId,
+                                 DenseSet<Value> &producerConsumerMemrefs,
+                                 MemRefDependenceGraph *mdg) {
+  SmallVector<Operation *> storeOps;
+  SmallVector<Operation *> loadOps;
+
+  auto *srcNode = mdg->getNode(srcId);
+  auto *dstNode = mdg->getNode(dstId);
+
+  for (Value memref : producerConsumerMemrefs) {
+    srcNode->getStoreOpsForMemref(memref, &storeOps);
+    dstNode->getLoadOpsForMemref(memref, &loadOps);
+
+    for (Operation *storeOp : storeOps) {
+      Value storeValue =
+          cast<AffineWriteOpInterface>(storeOp).getValueToStore();
+      auto storeShapedType = dyn_cast<ShapedType>(storeValue.getType());
+
+      if (!storeShapedType)
+        continue;
+
+      for (Operation *loadOp : loadOps) {
+        Value loadValue = cast<AffineReadOpInterface>(loadOp).getValue();
+        auto loadShapedType = dyn_cast<ShapedType>(loadValue.getType());
+
+        if (!loadShapedType)
+          continue;
+
+        for (int i = 0; i < loadShapedType.getRank(); ++i) {
+          auto loadDim = loadShapedType.getDimSize(i);
+          auto storeDim = storeShapedType.getDimSize(i);
+
+          if (loadDim > storeDim)
+            return false;
+        }
+      }
+    }
+
+    storeOps.clear();
+    loadOps.clear();
+  }
+
+  return true;
+}
+
+/// Checks the shapes of the loads and stores of each memref in
+/// the producer/consumer chains. If the load shapes are larger
+/// than the stores then we cannot fuse the loops. The loads
+/// would have a dependency on the values stored.
+static bool checkVectorLoadStoreOps(unsigned srcId, unsigned dstId,
+                                    DenseSet<Value> &producerConsumerMemrefs,
+                                    MemRefDependenceGraph *mdg) {
+  SmallVector<Operation *> storeOps;
+  SmallVector<Operation *> loadOps;
+
+  auto *srcNode = mdg->getNode(srcId);
+  auto *dstNode = mdg->getNode(dstId);
+
+  for (Value memref : producerConsumerMemrefs) {
+    srcNode->getStoreOpsForMemref(memref, &storeOps);
+    dstNode->getLoadOpsForMemref(memref, &loadOps);
+
+    for (Operation *storeOp : storeOps) {
+      auto vectorStoreOp = dyn_cast<AffineVectorStoreOp>(storeOp);
+
+      if (!vectorStoreOp)
+        continue;
+
+      auto storeVecType = vectorStoreOp.getVectorType();
+
+      for (Operation *loadOp : loadOps) {
+        auto vectorLoadOp = dyn_cast<AffineVectorLoadOp>(loadOp);
+
+        if (!vectorLoadOp)
+          return false;
+
+        auto loadVecType = vectorLoadOp.getVectorType();
+
+        if (loadVecType.getRank() != storeVecType.getRank())
+          return false;
+
+        for (int i = 0; i < loadVecType.getRank(); ++i) {
+          auto loadDim = loadVecType.getDimSize(i);
+          auto storeDim = storeVecType.getDimSize(i);
+
+          if (loadDim > storeDim)
+            return false;
+        }
+      }
+    }
+
+    storeOps.clear();
+    loadOps.clear();
+  }
+
+  return true;
+}
+
 /// A memref escapes in the context of the fusion pass if either:
 ///   1. it (or its alias) is a block argument, or
 ///   2. created by an op not known to guarantee alias freedom,
-///   3. it (or its alias) are used by ops other than affine dereferencing ops
-///   (e.g., by call op, memref load/store ops, alias creating ops, unknown ops,
-///   terminator ops, etc.); such ops do not deference the memref in an affine
-///   way.
+///   3. it (or its alias) is used by ops other than affine dereferencing
+///   ops (e.g., by call op, memref load/store ops, alias creating ops,
+///   unknown ops, terminator ops, etc.); such ops do not dereference the
+///   memref in an affine way.
 static bool isEscapingMemref(Value memref, Block *block) {
   Operation *defOp = memref.getDefiningOp();
   // Check if 'memref' is a block argument.
@@ -237,6 +349,57 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
   node->op = newRootForOp;
 }
 
+static Value createPrivateVectorOpMemRef(
+    AffineForOp forOp, Operation *srcStoreOpInst, unsigned dstLoopDepth,
+    std::optional<unsigned> fastMemorySpace, uint64_t localBufSizeThreshold) {
+  Operation *forInst = forOp.getOperation();
+
+  // Builder positioned just before 'forOp'. NOTE(review): unused below.
+  OpBuilder b(forInst);
+  // Builder for the region enclosing 'forOp'; the alloc is created with it.
+  OpBuilder top(forInst->getParentRegion());
+  // View the source store through the affine write interface.
+  auto srcAffineOp = cast<AffineWriteOpInterface>(srcStoreOpInst);
+
+  auto oldMemRef = srcAffineOp.getMemRef();
+  auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
+  unsigned rank = oldMemRefType.getRank();
+
+  auto srcOpResult = srcAffineOp.getValueToStore();
+  auto shapedType = dyn_cast<ShapedType>(srcOpResult.getType());
+
+  // Size the private buffer from the shape of the stored value; pick the
+  // fast memory space only if it is set and the buffer fits the threshold.
+  auto eltSize = getMemRefIntOrFloatEltSizeInBytes(oldMemRefType);
+  assert(eltSize && "memrefs with size elt types expected");
+  uint64_t bufSize = *eltSize * shapedType.getNumElements();
+  unsigned newMemSpace;
+  if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
+    newMemSpace = *fastMemorySpace;
+  } else {
+    newMemSpace = oldMemRefType.getMemorySpaceAsInt();
+  }
+
+  auto newMemRefType = MemRefType::get(
+      shapedType.getShape(), oldMemRefType.getElementType(), {}, newMemSpace);
+
+  // Allocate the private memref for fused loop 'forOp'; its shape is that
+  // of the stored value. NOTE(review): 'shapedType' is not null-checked.
+  Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType);
+
+  auto indexRemap = AffineMap::getMultiDimIdentityMap(rank, forOp.getContext());
+
+  // Rewrite uses of 'oldMemRef' to 'newMemRef' with an identity index remap.
+  LogicalResult res =
+      replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
+                               /*extraOperands=*/{},
+                               /*symbolOperands=*/{},
+                               /*domOpFilter=*/&*forOp.getBody()->begin());
+  assert(succeeded(res) &&
+         "replaceAllMemrefUsesWith should always succeed here");
+  return newMemRef;
+}
+
 // Creates and returns a private (single-user) memref for fused loop rooted
 // at 'forOp', with (potentially reduced) memref size based on the
 // MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
@@ -306,9 +469,9 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
   } else {
     newMemSpace = oldMemRefType.getMemorySpaceAsInt();
   }
+
   auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
                                        {}, newMemSpace);
-
   // Create new private memref for fused loop 'forOp'. 'newShape' is always
   // a constant shape.
   // TODO: Create/move alloc ops for private memrefs closer to their
@@ -322,7 +485,6 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
   remapExprs.reserve(rank);
   for (unsigned i = 0; i < rank; i++) {
     auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);
-
     auto remapExpr =
         simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
     remapExprs.push_back(remapExpr);
@@ -340,6 +502,7 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
   assert(succeeded(res) &&
          "replaceAllMemrefUsesWith should always succeed here");
   (void)res;
+
   return newMemRef;
 }
 
@@ -516,6 +679,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
     // nest slice 'slice' were to be inserted into the dst loop nest at loop
     // depth 'i'.
     MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
+
     if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
                                         &slice))) {
       LLVM_DEBUG(llvm::dbgs()
@@ -798,7 +962,6 @@ struct GreedyFusion {
   /// No fusion is performed when producers with a user count greater than
   /// `maxSrcUserCount` for any of the memrefs involved.
   void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) {
-    LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
     // Skip if this node was removed (fused into another node).
     if (mdg->nodes.count(dstId) == 0)
       return;
@@ -858,8 +1021,17 @@ struct GreedyFusion {
             }))
           continue;
 
-        // Gather memrefs in 'srcNode' that are written and escape out of the
-        // block (e.g., memref block arguments, returned memrefs,
+        if (!checkVectorLoadStoreOps(srcId, dstId, producerConsumerMemrefs,
+                                     mdg)) {
+          LLVM_DEBUG(llvm::dbgs()
+                     << "Can't fuse: vector loop fusion invalid due to either "
+                        "src or dst ops are not affine vector ops or load "
+                        "dependent on a larger store region\n");
+          continue;
+        }
+
+        // Gather memrefs in 'srcNode' that are written and escape out of
+        // the block (e.g., memref block arguments, returned memrefs,
         // memrefs passed to function calls, etc.).
         DenseSet<Value> srcEscapingMemRefs;
         gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
@@ -907,7 +1079,6 @@ struct GreedyFusion {
             dstMemrefOps.push_back(op);
         unsigned dstLoopDepthTest =
             getInnermostCommonLoopDepth(dstMemrefOps) - numSurroundingLoops;
-
         // Check the feasibility of fusing src loop nest into dst loop nest
         // at loop depths in range [1, dstLoopDepthTest].
         unsigned maxLegalFusionDepth = 0;
@@ -976,9 +1147,6 @@ struct GreedyFusion {
           if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
                                      removeSrcNode)) {
             // Create a private version of this memref.
-            LLVM_DEBUG(llvm::dbgs()
-                       << "Creating private memref for " << memref << '\n');
-            // Create a private version of this memref.
             privateMemrefs.insert(memref);
           }
         }
@@ -1019,9 +1187,19 @@ struct GreedyFusion {
             // private memref footprint.
             SmallVector<Operation *, 4> &storesForMemref =
                 memrefToStoresPair.second;
-            Value newMemRef = createPrivateMemRef(
-                dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
-                fastMemorySpace, localBufSizeThreshold);
+            Operation *srcStoreOpInst = storesForMemref[0];
+            Value newMemRef;
+
+            if (isa<AffineVectorLoadOp, AffineVectorStoreOp>(srcStoreOpInst)) {
+              newMemRef = createPrivateVectorOpMemRef(
+                  dstAffineForOp, srcStoreOpInst, bestDstLoopDepth,
+                  fastMemorySpace, localBufSizeThreshold);
+            } else {
+              newMemRef = createPrivateMemRef(dstAffineForOp, srcStoreOpInst,
+                                              bestDstLoopDepth, fastMemorySpace,
+                                              localBufSizeThreshold);
+            }
+
             // Create new node in dependence graph for 'newMemRef' alloc op.
             unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
             // Add edge from 'newMemRef' node to dstNode.
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index ea144f73bb21c6..bba2d13cc14973 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER
 // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{fusion-maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL
 // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(spirv.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=SPIRV
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion))' -split-input-file | FileCheck %s
 
 // Part I of fusion tests in  mlir/test/Transforms/loop-fusion.mlir.
 // Part II of fusion tests in mlir/test/Transforms/loop-fusion-2.mlir
@@ -285,3 +286,147 @@ module {
     spirv.ReturnValue %3 : !spirv.array<8192 x f32>
   }
 }
+
+// -----
+
+// Basic test for not fusing loops where a vector load depends on
+// the entire result of a previous loop (store shape < load shape).
+
+// CHECK-LABEL: func @should_not_fuse_across_memref_store_load_bounds
+func.func @should_not_fuse_across_memref_store_load_bounds() {
+  %a = memref.alloc() : memref<64x512xf32>
+  %b = memref.alloc() : memref<64x512xf32>
+  %c = memref.alloc() : memref<64x512xf32>
+  %d = memref.alloc() : memref<64x4096xf32>
+
+  affine.for %j = 0 to 8 {
+      %lhs = affine.vector_load %a[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+      %rhs = affine.vector_load %b[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+      %res = arith.addf %lhs, %rhs : vector<64x64xf32>
+      affine.vector_store %res, %c[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+  }
+
+  affine.for %j = 0 to 8 {
+      %lhs = affine.vector_load %c[0, 0] : memref<64x512xf32>, vector<64x512xf32>
+      %rhs = affine.vector_load %d[0, %j * 512] : memref<64x4096xf32>, vector<64x512xf32>
+      %res = arith.subf %lhs, %rhs : vector<64x512xf32>
+      affine.vector_store %res, %d[0, %j * 512] : memref<64x4096xf32>, vector<64x512xf32>
+  }
+
+  return
+}
+// CHECK: %[[a:.*]] = memref.alloc() : memref<64x512xf32>
+// CHECK: %[[b:.*]] = memref.alloc() : memref<64x512xf32>
+// CHECK: %[[c:.*]] = memref.alloc() : memref<64x512xf32>
+// CHECK: %[[d:.*]] = memref.alloc() : memref<64x4096xf32>
+// CHECK: affine.for %[[j:.*]] = 0 to 8
+// CHECK: %[[lhs:.*]] = affine.vector_load %[[a]][0, %[[j]] * 64] : memref<64x512xf32>, vector<64x64xf32>
+// CHECK: %[[rhs:.*]] = affine.vector_load %[[b]][0, %[[j]] * 64] : memref<64x512xf32>, vector<64x64xf32>
+// CHECK: %[[res:.*]] = arith.addf %[[lhs]], %[[rhs]] : vector<64x64xf32>
+// CHECK: affine.vector_store %[[res]], %[[c]][0, %[[j]] * 64] : memref<64x512xf32>, vector<64x64xf32>
+// CHECK: affine.for %[[j_2:.*]] = 0 to 8
+// CHECK: %[[lhs_2:.*]] = affine.vector_load %[[c]][0, 0] : memref<64x512xf32>, vector<64x512xf32>
+// CHECK: %[[rhs_2:.*]] = affine.vector_load %[[d]][0, %[[j_2]] * 512] : memref<64x4096xf32>, vector<64x512xf32>
+// CHECK: %[[res_2:.*]] = arith.subf %[[lhs_2]], %[[rhs_2]] : vector<64x512xf32>
+// CHECK: affine.vector_store %[[res_2]], %[[d]][0, %[[j_2]] * 512] : memref<64x4096xf32>, vector<64x512xf32>
+// CHECK: return
+
+// -----
+
+// Basic test for not fusing loops where the dependencies involve
+// an affine vector store and scalar (non-vector) affine loads.
+
+// CHECK-LABEL: func @should_not_fuse_vector_store_non_vector_load
+func.func @should_not_fuse_vector_store_non_vector_load() -> memref<64x4096xf32> {
+  %c0 = arith.constant 0 : index
+  %a = memref.alloc() : memref<64x512xf32> 
+  %b = memref.alloc() : memref<64x512xf32>
+  %c = memref.alloc() : memref<64x512xf32> 
+  %d = memref.alloc() : memref<64x4096xf32>
+
+  affine.for %j = 0 to 8 {
+    %lhs = affine.vector_load %a[%c0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+    %rhs = affine.vector_load %b[%c0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+    %res = arith.addf %lhs, %rhs : vector<64x64xf32>
+    affine.vector_store %res, %c[%c0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+  }
+
+  affine.for %k = 0 to 64 {
+    affine.for %m = 0 to 4096 {
+      affine.for %l = 0 to 512 {
+        %lhs = affine.load %c[%k, %l] : memref<64x512xf32>
+        %rhs = affine.load %d[%k, %m] : memref<64x4096xf32>
+        %res = arith.subf %lhs, %rhs : f32
+        affine.store %res, %d[%k, %m] : memref<64x4096xf32>
+      }
+    }
+  }
+
+  return %d : memref<64x4096xf32>
+}
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[a:.*]] = memref.alloc() : memref<64x512xf32>
+// CHECK: %[[b:.*]] = memref.alloc() : memref<64x512xf32>
+// CHECK: %[[c:.*]] = memref.alloc() : memref<64x512xf32>
+// CHECK: %[[d:.*]] = memref.alloc() : memref<64x4096xf32>
+// CHECK: affine.for %[[j:.*]] = 0 to 8 {
+// CHECK:   %[[lhs:.*]] = affine.vector_load %[[a]][%[[c0]], %[[j]] * 64] : memref<64x512xf32>, vector<64x64xf32>
+// CHECK:   %[[rhs:.*]] = affine.vector_load %[[b]][%[[c0]], %[[j]] * 64] : memref<64x512xf32>, vector<64x64xf32>
+// CHECK:   %[[res:.*]] = arith.addf %[[lhs]], %[[rhs]] : vector<64x64xf32>
+// CHECK:   affine.vector_store %[[res]], %[[c]][%[[c0]], %[[j]] * 64] : memref<64x512xf32>, vector<64x64xf32>
+// CHECK: }
+// CHECK: affine.for %[[k:.*]] = 0 to 64 {
+// CHECK:   affine.for %[[l:.*]] = 0 to 4096 {
+// CHECK:     affine.for %[[m:.*]] = 0 to 512 {
+// CHECK:       %[[lhs_2:.*]] = affine.load %[[c]][%[[k]], %[[m]]] : memref<64x512xf32>
+// CHECK:       %[[rhs_2:.*]] = affine.load %[[d]][%[[k]], %[[l]]] : memref<64x4096xf32>
+// CHECK:       %[[res_2:.*]] = arith.subf %[[lhs_2]], %[[rhs_2]] : f32
+// CHECK:       affine.store %[[res_2]], %[[d]][%[[k]], %[[l]]] : memref<64x4096xf32>
+// CHECK:     }
+// CHECK:   }
+// CHECK: }
+// CHECK: return %[[d]] : memref<64x4096xf32>
+
+// -----
+
+// Basic test for fusing loops where a vector load depends on the partial
+// result of a previous loop (store shape >= load shape; equal here).
+
+// CHECK-LABEL: func @should_fuse_across_memref_store_load_bounds
+func.func @should_fuse_across_memref_store_load_bounds() -> memref<64x4096xf32> {
+  %c0 = arith.constant 0 : index
+  %a = memref.alloc() : memref<64x512xf32> 
+  %b = memref.alloc() : memref<64x512xf32>
+  %c = memref.alloc() : memref<64x512xf32> 
+  %d = memref.alloc() : memref<64x4096xf32>
+
+  affine.for %j = 0 to 8 {
+      %lhs = affine.vector_load %a[%c0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+      %rhs = affine.vector_load %b[%c0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+      %res = arith.addf %lhs, %rhs : vector<64x64xf32>
+      affine.vector_store %res, %c[%c0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+  }
+
+  affine.for %j = 0 to 8 {
+      %lhs = affine.vector_load %c[%c0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+      %rhs = affine.vector_load %d[%c0, %j * 512] : memref<64x4096xf32>, vector<64x64xf32>
+      %res = arith.subf %lhs, %rhs : vector<64x64xf32>
+      affine.vector_store %res, %d[%c0, %j * 512] : memref<64x4096xf32>, vector<64x64xf32>
+  }
+  return %d : memref<64x4096xf32>
+}
+// CHECK: %[[private:.*]] = memref.alloc() : memref<64x64xf32>
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[a:.*]] = memref.alloc() : memref<64x512xf32>
+// CHECK: %[[b:.*]] = memref.alloc() : memref<64x512xf32>
+// CHECK: %[[c:.*]] = memref.alloc() : memref<64x4096xf32>
+// CHECK: affine.for %[[j:.*]] = 0 to 8
+// CHECK: %[[lhs:.*]] = affine.vector_load %[[a]][%[[c0]], %[[j]] * 64] : memref<64x512xf32>, vector<64x64xf32>
+// CHECK: %[[rhs:.*]] = affine.vector_load %[[b]][%[[c0]], %[[j]] * 64] : memref<64x512xf32>, vector<64x64xf32>
+// CHECK: %[[res:.*]] = arith.addf %[[lhs]], %[[rhs]] : vector<64x64xf32>
+// CHECK: affine.vector_store %[[res]], %[[private]][0, %[[j]] * 64] : memref<64x64xf32>, vector<64x64xf32>
+// CHECK: %[[lhs_2:.*]] = affine.vector_load %[[private]][0, %[[j]] * 64] : memref<64x64xf32>, vector<64x64xf32>
+// CHECK: %[[rhs_2:.*]] = affine.vector_load %[[c]][%[[c0]], %[[j]] * 512] : memref<64x4096xf32>, vector<64x64xf32>
+// CHECK: %[[res_2:.*]] = arith.subf %[[lhs_2]], %[[rhs_2]] : vector<64x64xf32>
+// CHECK: affine.vector_store %[[res_2]], %[[c]][%[[c0]], %[[j]] * 512] : memref<64x4096xf32>, vector<64x64xf32>
+// CHECK: return %[[c]] : memref<64x4096xf32>

>From e2d83980bb3a24c68c948066657063e3bbe13e1e Mon Sep 17 00:00:00 2001
From: brod4910 <brod4910 at gmail.com>
Date: Mon, 13 Jan 2025 14:47:11 -0700
Subject: [PATCH 2/3] remove unused imports and dead code

---
 .../Dialect/Affine/Transforms/LoopFusion.cpp  | 56 +------------------
 1 file changed, 3 insertions(+), 53 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 5039162ac5bc06..c61fe62e3889d9 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -10,7 +10,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Analysis/Presburger/IntegerRelation.h"
 #include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
 #include "mlir/Dialect/Affine/Passes.h"
 
@@ -29,12 +28,10 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/Types.h"
-#include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
@@ -187,56 +184,9 @@ gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
                                 producerConsumerMemrefs);
 }
 
-/// Checks the shapes of the loads and stores of each memref in
-/// the producer/consumer chains. If the load shapes are larger
-/// than the stores then we cannot fuse the loops. The loads
-/// would have a dependency on the values stored.
-static bool checkLoadStoreShapes(unsigned srcId, unsigned dstId,
-                                 DenseSet<Value> &producerConsumerMemrefs,
-                                 MemRefDependenceGraph *mdg) {
-  SmallVector<Operation *> storeOps;
-  SmallVector<Operation *> loadOps;
-
-  auto *srcNode = mdg->getNode(srcId);
-  auto *dstNode = mdg->getNode(dstId);
-
-  for (Value memref : producerConsumerMemrefs) {
-    srcNode->getStoreOpsForMemref(memref, &storeOps);
-    dstNode->getLoadOpsForMemref(memref, &loadOps);
-
-    for (Operation *storeOp : storeOps) {
-      Value storeValue =
-          cast<AffineWriteOpInterface>(storeOp).getValueToStore();
-      auto storeShapedType = dyn_cast<ShapedType>(storeValue.getType());
-
-      if (!storeShapedType)
-        continue;
-
-      for (Operation *loadOp : loadOps) {
-        Value loadValue = cast<AffineReadOpInterface>(loadOp).getValue();
-        auto loadShapedType = dyn_cast<ShapedType>(loadValue.getType());
-
-        if (!loadShapedType)
-          continue;
-
-        for (int i = 0; i < loadShapedType.getRank(); ++i) {
-          auto loadDim = loadShapedType.getDimSize(i);
-          auto storeDim = storeShapedType.getDimSize(i);
-
-          if (loadDim > storeDim)
-            return false;
-        }
-      }
-    }
-
-    storeOps.clear();
-    loadOps.clear();
-  }
-
-  return true;
-}
-
-/// Checks the shapes of the loads and stores of each memref in
+/// Performs two checks:
+/// Firstly, checks if both src/dst ops are vector operations.
+/// Secondly, the shapes of the loads and stores of each memref in
 /// the producer/consumer chains. If the load shapes are larger
 /// than the stores then we cannot fuse the loops. The loads
 /// would have a dependency on the values stored.

>From c49200cfa46a6d102d319e7b1e28fcf6421c1608 Mon Sep 17 00:00:00 2001
From: brod4910 <brod4910 at gmail.com>
Date: Mon, 13 Jan 2025 14:57:29 -0700
Subject: [PATCH 3/3] adds doc to createPrivateVectorOpMemRef

---
 mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index c61fe62e3889d9..419209678d4085 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -299,6 +299,10 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
   node->op = newRootForOp;
 }
 
+/// Creates a private memref to be used by vector operations.
+/// TODO: This is separate from 'createPrivateMemRef' only because the
+/// system for calculating the bounds and constraints doesn't support
+/// vector operations.
 static Value createPrivateVectorOpMemRef(
     AffineForOp forOp, Operation *srcStoreOpInst, unsigned dstLoopDepth,
     std::optional<unsigned> fastMemorySpace, uint64_t localBufSizeThreshold) {



More information about the Mlir-commits mailing list