[Mlir-commits] [mlir] [mlir][Affine] Fix vector fusion legality and buffer sizing (PR #167229)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Nov 9 08:29:32 PST 2025


https://github.com/Men-cotton created https://github.com/llvm/llvm-project/pull/167229

Fixes #115849 and #115989

- Guard loop fusion with explicit vector-shape checks: dependent vector loads/stores on the same memref (at least one of them a write) whose vector shapes differ are rejected, so illegal vector fusions are no longer performed.
- Size private fusion buffers conservatively for vector stores by threading a `minShape` through `MemRefRegion::getConstantBoundingSizeAndShape`, ensuring temporary memrefs are large enough to hold entire vector tiles.
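- Add `loop-fusion-vector.mlir`, covering rejection of dependent mismatched vectors, correct private-buffer sizing for vector stores, scalar-producer/vector-consumer fusion, and sibling-mode fusion of read-only loads with different vector shapes (which remains legal).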

>From 17204e804dfa1bf82b6c3577d1bc6da193329982 Mon Sep 17 00:00:00 2001
From: mencotton <mencotton0410 at gmail.com>
Date: Mon, 3 Nov 2025 20:12:56 +0900
Subject: [PATCH] [mlir][Affine] Fix vector fusion legality and buffer sizing

---
 .../mlir/Dialect/Affine/Analysis/Utils.h      |  6 +-
 mlir/lib/Dialect/Affine/Analysis/Utils.cpp    | 11 ++-
 .../Dialect/Affine/Transforms/LoopFusion.cpp  | 25 ++++-
 .../Dialect/Affine/Utils/LoopFusionUtils.cpp  | 50 ++++++++++
 .../Dialect/Affine/loop-fusion-vector.mlir    | 97 +++++++++++++++++++
 5 files changed, 181 insertions(+), 8 deletions(-)
 create mode 100644 mlir/test/Dialect/Affine/loop-fusion-vector.mlir

diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
index df4145db90a61..9ee85e4b19308 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -547,10 +547,12 @@ struct MemRefRegion {
   /// use int64_t instead of uint64_t since index types can be at most
   /// int64_t. `lbs` are set to the lower bound maps for each of the rank
   /// dimensions where each of these maps is purely symbolic in the constraints
-  /// set's symbols.
+  /// set's symbols. If `minShape` is non-empty (one entry per dimension),
+  /// each extent `shape[d]` is at least `minShape[d]` (`lbs` unchanged).
   std::optional<int64_t> getConstantBoundingSizeAndShape(
       SmallVectorImpl<int64_t> *shape = nullptr,
-      SmallVectorImpl<AffineMap> *lbs = nullptr) const;
+      SmallVectorImpl<AffineMap> *lbs = nullptr,
+      ArrayRef<int64_t> minShape = {}) const;
 
   /// Gets the lower and upper bound map for the dimensional variable at
   /// `pos`.
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index f38493bc9a96e..4e934a3b6e580 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -25,6 +25,7 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/DebugLog.h"
 #include "llvm/Support/raw_ostream.h"
+#include <algorithm>
 #include <optional>
 
 #define DEBUG_TYPE "analysis-utils"
@@ -1158,10 +1159,12 @@ unsigned MemRefRegion::getRank() const {
 }
 
 std::optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
-    SmallVectorImpl<int64_t> *shape, SmallVectorImpl<AffineMap> *lbs) const {
+    SmallVectorImpl<int64_t> *shape, SmallVectorImpl<AffineMap> *lbs,
+    ArrayRef<int64_t> minShape) const {
   auto memRefType = cast<MemRefType>(memref.getType());
   MLIRContext *context = memref.getContext();
   unsigned rank = memRefType.getRank();
+  assert((minShape.empty() || minShape.size() == rank) && "bad minShape size");
   if (shape)
     shape->reserve(rank);
 
@@ -1203,12 +1206,19 @@ std::optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
       lb = AffineMap::get(/*dimCount=*/0, cstWithShapeBounds.getNumSymbolVars(),
                           /*result=*/getAffineConstantExpr(0, context));
     }
-    numElements *= diffConstant;
+    // Treat `minShape[d]` as the number of elements each access covers along
+    // `d` from its start index (e.g. a vector width). The bounding box above
+    // only spans start indices, so widen it by `minShape[d] - 1`: accesses of
+    // width 16 starting at 0..48 need 64 elements, not max(49, 16) = 49.
+    int64_t finalDiff = diffConstant;
+    if (!minShape.empty() && minShape[d] > 0)
+      finalDiff = std::max(diffConstant + minShape[d] - 1, minShape[d]);
+    numElements *= finalDiff;
     // Populate outputs if available.
     if (lbs)
       lbs->push_back(lb);
     if (shape)
-      shape->push_back(diffConstant);
+      shape->push_back(finalDiff);
   }
   return numElements;
 }
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index ff0157eb9e4f3..0fa140027b4c3 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -28,6 +28,7 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/DebugLog.h"
 #include "llvm/Support/raw_ostream.h"
+#include <algorithm>
 #include <iomanip>
 #include <optional>
 #include <sstream>
@@ -376,10 +377,28 @@ static Value createPrivateMemRef(AffineForOp forOp,
   SmallVector<int64_t, 4> newShape;
   SmallVector<AffineMap, 4> lbs;
   lbs.reserve(rank);
+  // A vector store writes a whole tile, so the private buffer must be at least
+  // as large as the vector in every dimension the vector spans.
+  SmallVector<int64_t, 4> minShape;
+  if (auto vectorStore = dyn_cast<AffineVectorStoreOp>(srcStoreOp)) {
+    VectorType vectorType = vectorStore.getVectorType();
+    unsigned vectorRank = vectorType.getRank();
+    // Scalable sizes are only known at runtime, so no constant-size buffer can
+    // be guaranteed to hold the whole tile.
+    if (vectorRank > rank || vectorType.isScalable()) {
+      LDBG() << "Private memref creation unsupported for vector store with "
+             << "rank greater than memref rank or scalable dimensions";
+      return nullptr;
+    }
+    // The vector spans the trailing `vectorRank` memref dimensions; leading
+    // dimensions impose no minimum (0).
+    minShape.assign(rank - vectorRank, 0);
+    llvm::append_range(minShape, vectorType.getShape());
+  }
   // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
   // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
   std::optional<int64_t> numElements =
-      region.getConstantBoundingSizeAndShape(&newShape, &lbs);
+      region.getConstantBoundingSizeAndShape(&newShape, &lbs, minShape);
   assert(numElements && "non-constant number of elts in local buffer");
 
   const FlatAffineValueConstraints *cst = region.getConstraints();
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
index c6abb0d734d88..3963fab97749b 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -31,6 +31,22 @@
 using namespace mlir;
 using namespace mlir::affine;
 
+/// Returns the vector type of an affine vector load/store op, else nullopt.
+static std::optional<VectorType> getAffineVectorType(Operation *op) {
+  if (auto vectorLoad = dyn_cast<AffineVectorLoadOp>(op))
+    return vectorLoad.getVectorType();
+  if (auto vectorStore = dyn_cast<AffineVectorStoreOp>(op))
+    return vectorStore.getVectorType();
+  return std::nullopt;
+}
+
+/// Returns the memref accessed by `op`, which must be an affine read or write.
+static Value getAccessMemRef(Operation *op) {
+  if (auto loadOp = dyn_cast<AffineReadOpInterface>(op))
+    return loadOp.getMemRef();
+  return cast<AffineWriteOpInterface>(op).getMemRef();
+}
+
 // Gathers all load and store memref accesses in 'opA' into 'values', where
 // 'values[memref] == true' for each store operation.
 static void getLoadAndStoreMemRefAccesses(Operation *opA,
@@ -334,6 +350,40 @@ FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp,
     break;
   }
 
+  // Reject fusion when a producer/consumer pair of vector accesses to the same
+  // memref (not both reads) has different vector types and is dependent.
+  // The dependence check mirrors `computeSliceUnion` so that disjoint accesses
+  // stay fusable; it is costly, so cheap filters run before it.
+  // NOTE(review): that check presumably compares only the start indices of
+  // vector accesses, so partial overlaps may go undetected.
+  for (Operation *srcOp : strategyOpsA) {
+    std::optional<VectorType> srcVectorType = getAffineVectorType(srcOp);
+    if (!srcVectorType)
+      continue;
+    Value srcMemRef = getAccessMemRef(srcOp);
+    bool srcIsRead = isa<AffineReadOpInterface>(srcOp);
+    for (Operation *dstOp : opsB) {
+      if (getAccessMemRef(dstOp) != srcMemRef ||
+          (srcIsRead && isa<AffineReadOpInterface>(dstOp)))
+        continue;
+      // Type equality also tells scalable and fixed-size dimensions apart.
+      std::optional<VectorType> dstVectorType = getAffineVectorType(dstOp);
+      if (!dstVectorType || *dstVectorType == *srcVectorType)
+        continue;
+      DependenceResult result = checkMemrefAccessDependence(
+          MemRefAccess(srcOp), MemRefAccess(dstOp),
+          /*loopDepth=*/numCommonLoops + 1);
+      if (result.value == DependenceResult::Failure) {
+        LDBG() << "Dependency check failed";
+        return FusionResult::FailPrecondition;
+      }
+      if (result.value == DependenceResult::NoDependence)
+        continue;
+      LDBG() << "Mismatching vector shapes between producer and consumer";
+      return FusionResult::FailPrecondition;
+    }
+  }
+
   // Compute union of computation slices computed between all pairs of ops
   // from 'forOpA' and 'forOpB'.
   SliceComputationResult sliceComputationResult = affine::computeSliceUnion(
diff --git a/mlir/test/Dialect/Affine/loop-fusion-vector.mlir b/mlir/test/Dialect/Affine/loop-fusion-vector.mlir
new file mode 100644
index 0000000000000..f5dd13c36f8d3
--- /dev/null
+++ b/mlir/test/Dialect/Affine/loop-fusion-vector.mlir
@@ -0,0 +1,101 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion))' -split-input-file | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING
+
+// Dependent vector store/load pairs with different shapes must not be fused.
+// CHECK-LABEL: func.func @skip_fusing_mismatched_vectors
+// CHECK: affine.for %{{.*}} = 0 to 8 {
+// CHECK:   affine.vector_store {{.*}} : memref<64x512xf32>, vector<64x64xf32>
+// CHECK: }
+// CHECK: affine.for %{{.*}} = 0 to 8 {
+// CHECK:   affine.vector_load {{.*}} : memref<64x512xf32>, vector<64x512xf32>
+// CHECK: }
+func.func @skip_fusing_mismatched_vectors(%a: memref<64x512xf32>, %b: memref<64x512xf32>, %c: memref<64x512xf32>, %d: memref<64x4096xf32>, %e: memref<64x4096xf32>) {
+  affine.for %j = 0 to 8 {
+    %lhs = affine.vector_load %a[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+    %rhs = affine.vector_load %b[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+    %res = arith.addf %lhs, %rhs : vector<64x64xf32>
+    affine.vector_store %res, %c[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+  }
+
+  affine.for %j = 0 to 8 {
+    %lhs = affine.vector_load %c[0, 0] : memref<64x512xf32>, vector<64x512xf32>
+    %rhs = affine.vector_load %d[0, %j * 512] : memref<64x4096xf32>, vector<64x512xf32>
+    %res = arith.subf %lhs, %rhs : vector<64x512xf32>
+    affine.vector_store %res, %d[0, %j * 512] : memref<64x4096xf32>, vector<64x512xf32>
+  }
+  return
+}
+
+// -----
+
+// The private buffer must hold the whole stored vector (1x64, not 1x1).
+// CHECK-LABEL: func.func @vector_private_memref
+// CHECK: memref.alloc() : memref<1x64xf32>
+// CHECK-NOT: memref<1x1xf32>
+// CHECK: affine.vector_store {{.*}} : memref<1x64xf32>, vector<64xf32>
+func.func @vector_private_memref(%src: memref<10x64xf32>, %dst: memref<10x64xf32>) {
+  %tmp = memref.alloc() : memref<10x64xf32>
+  affine.for %i = 0 to 10 {
+    %vec = affine.vector_load %src[%i, 0] : memref<10x64xf32>, vector<64xf32>
+    affine.vector_store %vec, %tmp[%i, 0] : memref<10x64xf32>, vector<64xf32>
+  }
+
+  affine.for %i = 0 to 10 {
+    %vec = affine.vector_load %tmp[%i, 0] : memref<10x64xf32>, vector<64xf32>
+    affine.vector_store %vec, %dst[%i, 0] : memref<10x64xf32>, vector<64xf32>
+  }
+  return
+}
+
+// -----
+
+// A scalar producer may be fused with a vector consumer of the same memref.
+// CHECK-LABEL: func.func @fuse_scalar_vector
+// CHECK: %[[TMP:.*]] = memref.alloc() : memref<64xf32>
+// CHECK: affine.for %[[I:.*]] = 0 to 16 {
+// CHECK:   %[[S0:.*]] = affine.load %[[SRC:.*]][%[[I]] * 4] : memref<64xf32>
+// CHECK:   affine.store %[[S0]], %[[TMP]][%[[I]] * 4] : memref<64xf32>
+// CHECK:   %[[V:.*]] = affine.vector_load %[[TMP]][%[[I]] * 4] : memref<64xf32>, vector<4xf32>
+// CHECK:   affine.vector_store %[[V]], %[[DST:.*]][%[[I]] * 4] : memref<64xf32>, vector<4xf32>
+// CHECK: }
+func.func @fuse_scalar_vector(%src: memref<64xf32>, %dst: memref<64xf32>) {
+  %tmp = memref.alloc() : memref<64xf32>
+  affine.for %i = 0 to 16 {
+    %s0 = affine.load %src[%i * 4] : memref<64xf32>
+    affine.store %s0, %tmp[%i * 4] : memref<64xf32>
+    %s1 = affine.load %src[%i * 4 + 1] : memref<64xf32>
+    affine.store %s1, %tmp[%i * 4 + 1] : memref<64xf32>
+    %s2 = affine.load %src[%i * 4 + 2] : memref<64xf32>
+    affine.store %s2, %tmp[%i * 4 + 2] : memref<64xf32>
+    %s3 = affine.load %src[%i * 4 + 3] : memref<64xf32>
+    affine.store %s3, %tmp[%i * 4 + 3] : memref<64xf32>
+  }
+
+  affine.for %i = 0 to 16 {
+    %vec = affine.vector_load %tmp[%i * 4] : memref<64xf32>, vector<4xf32>
+    affine.vector_store %vec, %dst[%i * 4] : memref<64xf32>, vector<4xf32>
+  }
+  memref.dealloc %tmp : memref<64xf32>
+  return
+}
+
+// -----
+
+// Read-only accesses with different vector shapes may still be fused.
+// SIBLING-LABEL: func.func @sibling_vector_mismatch
+// SIBLING: affine.for %{{.*}} = 0 to 10 {
+// SIBLING:   affine.vector_load %{{.*}} : memref<10x16xf32>, vector<4xf32>
+// SIBLING:   affine.vector_load %{{.*}} : memref<10x16xf32>, vector<8xf32>
+// SIBLING:   affine.vector_load %{{.*}} : memref<10x16xf32>, vector<4xf32>
+// SIBLING: }
+func.func @sibling_vector_mismatch(%src: memref<10x16xf32>) {
+  affine.for %i = 0 to 10 {
+    %vec = affine.vector_load %src[%i, 0] : memref<10x16xf32>, vector<4xf32>
+  }
+
+  affine.for %i = 0 to 10 {
+    %wide = affine.vector_load %src[%i, 8] : memref<10x16xf32>, vector<8xf32>
+    %vec = affine.vector_load %src[%i, 0] : memref<10x16xf32>, vector<4xf32>
+  }
+  return
+}



More information about the Mlir-commits mailing list