[Mlir-commits] [mlir] [mlir][Affine] Fix vector fusion legality and buffer sizing (PR #167229)
llvmlistbot at llvm.org
Sun Nov 9 08:29:32 PST 2025
https://github.com/Men-cotton created https://github.com/llvm/llvm-project/pull/167229
Fixes #115849 and #115989
- Guard producer/consumer fusion with explicit vector-shape checks, rejecting dependent vector load/store pairs whose shapes differ so that illegal fusions of vectorized loops no longer slip through.
- Size private fusion buffers conservatively for vector stores by threading a `minShape` through `MemRefRegion::getConstantBoundingSizeAndShape`, ensuring temporary memrefs are large enough to hold entire vector tiles (see the sketch after this list).
- Add `loop-fusion-vector.mlir` to cover mismatched-vector rejection, correct vector buffer sizing, scalar↔vector fusion, and sibling-mode mismatches.
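
For readers who want the sizing intent without the MLIR plumbing, below is a minimal standalone C++ sketch (hypothetical helper names, not the actual `MemRefRegion` API) of the clamping described in the second bullet: each constant per-dimension extent is raised to at least the corresponding vector-store extent before the element count is accumulated.

// Minimal sketch of the `minShape` clamping; names are illustrative only.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Computes the element count of a bounding shape whose per-dimension extents
// are clamped from below by `minShape` (an empty `minShape` means no clamp).
static int64_t clampedBoundingSize(const std::vector<int64_t> &extents,
                                   const std::vector<int64_t> &minShape) {
  assert(minShape.empty() || minShape.size() == extents.size());
  int64_t numElements = 1;
  for (size_t d = 0; d < extents.size(); ++d) {
    int64_t ext =
        minShape.empty() ? extents[d] : std::max(extents[d], minShape[d]);
    numElements *= ext;
  }
  return numElements;
}

int main() {
  // A rank-2 region that touches a single scalar element per iteration, but
  // is written by an affine.vector_store of shape 1x64: the private buffer
  // must be sized for the whole vector tile, not the scalar footprint.
  std::vector<int64_t> extents = {1, 1};
  std::vector<int64_t> minShape = {1, 64};
  std::cout << clampedBoundingSize(extents, {}) << "\n";       // 1  (too small)
  std::cout << clampedBoundingSize(extents, minShape) << "\n"; // 64 (vector-safe)
  return 0;
}

In the patch itself this clamp lives in `getConstantBoundingSizeAndShape`, with `createPrivateMemRef` deriving `minShape` from the trailing dimensions of the vector store's shape.
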
From 17204e804dfa1bf82b6c3577d1bc6da193329982 Mon Sep 17 00:00:00 2001
From: mencotton <mencotton0410 at gmail.com>
Date: Mon, 3 Nov 2025 20:12:56 +0900
Subject: [PATCH] [mlir][Affine] Fix vector fusion legality and buffer sizing
---
.../mlir/Dialect/Affine/Analysis/Utils.h | 6 +-
mlir/lib/Dialect/Affine/Analysis/Utils.cpp | 11 ++-
.../Dialect/Affine/Transforms/LoopFusion.cpp | 25 ++++-
.../Dialect/Affine/Utils/LoopFusionUtils.cpp | 50 ++++++++++
.../Dialect/Affine/loop-fusion-vector.mlir | 97 +++++++++++++++++++
5 files changed, 181 insertions(+), 8 deletions(-)
create mode 100644 mlir/test/Dialect/Affine/loop-fusion-vector.mlir
diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
index df4145db90a61..9ee85e4b19308 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -547,10 +547,12 @@ struct MemRefRegion {
/// use int64_t instead of uint64_t since index types can be at most
/// int64_t. `lbs` are set to the lower bound maps for each of the rank
/// dimensions where each of these maps is purely symbolic in the constraints
- /// set's symbols.
+ /// set's symbols. If `minShape` is provided, each computed bound is at least
+ /// `minShape[d]` for dimension `d`.
std::optional<int64_t> getConstantBoundingSizeAndShape(
SmallVectorImpl<int64_t> *shape = nullptr,
- SmallVectorImpl<AffineMap> *lbs = nullptr) const;
+ SmallVectorImpl<AffineMap> *lbs = nullptr,
+ ArrayRef<int64_t> minShape = {}) const;
/// Gets the lower and upper bound map for the dimensional variable at
/// `pos`.
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index f38493bc9a96e..4e934a3b6e580 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -25,6 +25,7 @@
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
#include <optional>
#define DEBUG_TYPE "analysis-utils"
@@ -1158,10 +1159,12 @@ unsigned MemRefRegion::getRank() const {
}
std::optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
- SmallVectorImpl<int64_t> *shape, SmallVectorImpl<AffineMap> *lbs) const {
+ SmallVectorImpl<int64_t> *shape, SmallVectorImpl<AffineMap> *lbs,
+ ArrayRef<int64_t> minShape) const {
auto memRefType = cast<MemRefType>(memref.getType());
MLIRContext *context = memref.getContext();
unsigned rank = memRefType.getRank();
+ assert(minShape.empty() || minShape.size() == rank);
if (shape)
shape->reserve(rank);
@@ -1203,12 +1206,14 @@ std::optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
lb = AffineMap::get(/*dimCount=*/0, cstWithShapeBounds.getNumSymbolVars(),
/*result=*/getAffineConstantExpr(0, context));
}
- numElements *= diffConstant;
+ int64_t finalDiff =
+ minShape.empty() ? diffConstant : std::max(diffConstant, minShape[d]);
+ numElements *= finalDiff;
// Populate outputs if available.
if (lbs)
lbs->push_back(lb);
if (shape)
- shape->push_back(diffConstant);
+ shape->push_back(finalDiff);
}
return numElements;
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index ff0157eb9e4f3..0fa140027b4c3 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -28,6 +28,7 @@
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
#include <iomanip>
#include <optional>
#include <sstream>
@@ -376,10 +377,28 @@ static Value createPrivateMemRef(AffineForOp forOp,
SmallVector<int64_t, 4> newShape;
SmallVector<AffineMap, 4> lbs;
lbs.reserve(rank);
- // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
- // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
+ SmallVector<int64_t, 4> minShape;
+ ArrayRef<int64_t> minShapeRef;
+ if (auto vectorStore = dyn_cast<AffineVectorStoreOp>(srcStoreOp)) {
+ ArrayRef<int64_t> vectorShape = vectorStore.getVectorType().getShape();
+ unsigned vectorRank = vectorShape.size();
+ if (vectorRank > rank) {
+ LDBG() << "Private memref creation unsupported for vector store with "
+ << "rank greater than memref rank";
+ return nullptr;
+ }
+ minShape.assign(rank, 0);
+ for (unsigned i = 0; i < vectorRank; ++i) {
+ unsigned memDim = rank - vectorRank + i;
+ int64_t vecDim = vectorShape[i];
+ assert(!ShapedType::isDynamic(vecDim) &&
+ "vector store should have static shape");
+ minShape[memDim] = std::max(minShape[memDim], vecDim);
+ }
+ minShapeRef = minShape;
+ }
std::optional<int64_t> numElements =
- region.getConstantBoundingSizeAndShape(&newShape, &lbs);
+ region.getConstantBoundingSizeAndShape(&newShape, &lbs, minShapeRef);
assert(numElements && "non-constant number of elts in local buffer");
const FlatAffineValueConstraints *cst = region.getConstraints();
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
index c6abb0d734d88..3963fab97749b 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -31,6 +31,22 @@
using namespace mlir;
using namespace mlir::affine;
+/// Returns the vector type associated with an affine vector load/store op.
+static std::optional<VectorType> getAffineVectorType(Operation *op) {
+ if (auto vectorLoad = dyn_cast<AffineVectorLoadOp>(op))
+ return vectorLoad.getVectorType();
+ if (auto vectorStore = dyn_cast<AffineVectorStoreOp>(op))
+ return vectorStore.getVectorType();
+ return std::nullopt;
+}
+
+/// Returns the memref underlying an affine read/write op.
+static Value getAccessMemRef(Operation *op) {
+ if (auto loadOp = dyn_cast<AffineReadOpInterface>(op))
+ return loadOp.getMemRef();
+ return cast<AffineWriteOpInterface>(op).getMemRef();
+}
+
// Gathers all load and store memref accesses in 'opA' into 'values', where
// 'values[memref] == true' for each store operation.
static void getLoadAndStoreMemRefAccesses(Operation *opA,
@@ -334,6 +350,40 @@ FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp,
break;
}
+ // Guard vector fusion by matching producer/consumer vector shapes on actual
+ // dependence pairs (here we duplicate the early dependence check used in
+ // `computeSliceUnion` to avoid rejecting disjoint accesses).
+ for (Operation *srcOp : strategyOpsA) {
+ MemRefAccess srcAccess(srcOp);
+ auto srcVectorType = getAffineVectorType(srcOp);
+ bool srcIsRead = isa<AffineReadOpInterface>(srcOp);
+ for (Operation *dstOp : opsB) {
+ MemRefAccess dstAccess(dstOp);
+ if (srcAccess.memref != dstAccess.memref)
+ continue;
+ bool dstIsRead = isa<AffineReadOpInterface>(dstOp);
+ bool readReadAccesses = srcIsRead && dstIsRead;
+ DependenceResult result = checkMemrefAccessDependence(
+ srcAccess, dstAccess, /*loopDepth=*/numCommonLoops + 1,
+ /*dependenceConstraints=*/nullptr,
+ /*dependenceComponents=*/nullptr, readReadAccesses);
+ if (result.value == DependenceResult::Failure) {
+ LDBG() << "Dependency check failed";
+ return FusionResult::FailPrecondition;
+ }
+ if (result.value == DependenceResult::NoDependence)
+ continue;
+ if (readReadAccesses)
+ continue;
+ auto dstVectorType = getAffineVectorType(dstOp);
+ if (srcVectorType && dstVectorType &&
+ srcVectorType->getShape() != dstVectorType->getShape()) {
+ LDBG() << "Mismatching vector shapes between producer and consumer";
+ return FusionResult::FailPrecondition;
+ }
+ }
+ }
+
// Compute union of computation slices computed between all pairs of ops
// from 'forOpA' and 'forOpB'.
SliceComputationResult sliceComputationResult = affine::computeSliceUnion(
diff --git a/mlir/test/Dialect/Affine/loop-fusion-vector.mlir b/mlir/test/Dialect/Affine/loop-fusion-vector.mlir
new file mode 100644
index 0000000000000..f5dd13c36f8d3
--- /dev/null
+++ b/mlir/test/Dialect/Affine/loop-fusion-vector.mlir
@@ -0,0 +1,97 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion))' -split-input-file | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING
+
+// CHECK-LABEL: func.func @skip_fusing_mismatched_vectors
+// CHECK: affine.for %{{.*}} = 0 to 8 {
+// CHECK: affine.vector_store {{.*}} : memref<64x512xf32>, vector<64x64xf32>
+// CHECK: }
+// CHECK: affine.for %{{.*}} = 0 to 8 {
+// CHECK: affine.vector_load {{.*}} : memref<64x512xf32>, vector<64x512xf32>
+// CHECK: }
+func.func @skip_fusing_mismatched_vectors(%a: memref<64x512xf32>, %b: memref<64x512xf32>, %c: memref<64x512xf32>, %d: memref<64x4096xf32>, %e: memref<64x4096xf32>) {
+ affine.for %j = 0 to 8 {
+ %lhs = affine.vector_load %a[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+ %rhs = affine.vector_load %b[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+ %res = arith.addf %lhs, %rhs : vector<64x64xf32>
+ affine.vector_store %res, %c[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+ }
+
+ affine.for %j = 0 to 8 {
+ %lhs = affine.vector_load %c[0, 0] : memref<64x512xf32>, vector<64x512xf32>
+ %rhs = affine.vector_load %d[0, %j * 512] : memref<64x4096xf32>, vector<64x512xf32>
+ %res = arith.subf %lhs, %rhs : vector<64x512xf32>
+ affine.vector_store %res, %d[0, %j * 512] : memref<64x4096xf32>, vector<64x512xf32>
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_private_memref
+// CHECK: memref.alloc() : memref<1x64xf32>
+// CHECK-NOT: memref<1x1xf32>
+// CHECK: affine.vector_store {{.*}} : memref<1x64xf32>, vector<64xf32>
+func.func @vector_private_memref(%src: memref<10x64xf32>, %dst: memref<10x64xf32>) {
+ %tmp = memref.alloc() : memref<10x64xf32>
+ affine.for %i = 0 to 10 {
+ %vec = affine.vector_load %src[%i, 0] : memref<10x64xf32>, vector<64xf32>
+ affine.vector_store %vec, %tmp[%i, 0] : memref<10x64xf32>, vector<64xf32>
+ }
+
+ affine.for %i = 0 to 10 {
+ %vec = affine.vector_load %tmp[%i, 0] : memref<10x64xf32>, vector<64xf32>
+ affine.vector_store %vec, %dst[%i, 0] : memref<10x64xf32>, vector<64xf32>
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fuse_scalar_vector
+// CHECK: %[[TMP:.*]] = memref.alloc() : memref<64xf32>
+// CHECK: affine.for %[[I:.*]] = 0 to 16 {
+// CHECK: %[[S0:.*]] = affine.load %[[SRC:.*]][%[[I]] * 4] : memref<64xf32>
+// CHECK: affine.store %[[S0]], %[[TMP]][%[[I]] * 4] : memref<64xf32>
+// CHECK: %[[V:.*]] = affine.vector_load %[[TMP]][%[[I]] * 4] : memref<64xf32>, vector<4xf32>
+// CHECK: affine.vector_store %[[V]], %[[DST:.*]][%[[I]] * 4] : memref<64xf32>, vector<4xf32>
+// CHECK: }
+func.func @fuse_scalar_vector(%src: memref<64xf32>, %dst: memref<64xf32>) {
+ %tmp = memref.alloc() : memref<64xf32>
+ affine.for %i = 0 to 16 {
+ %s0 = affine.load %src[%i * 4] : memref<64xf32>
+ affine.store %s0, %tmp[%i * 4] : memref<64xf32>
+ %s1 = affine.load %src[%i * 4 + 1] : memref<64xf32>
+ affine.store %s1, %tmp[%i * 4 + 1] : memref<64xf32>
+ %s2 = affine.load %src[%i * 4 + 2] : memref<64xf32>
+ affine.store %s2, %tmp[%i * 4 + 2] : memref<64xf32>
+ %s3 = affine.load %src[%i * 4 + 3] : memref<64xf32>
+ affine.store %s3, %tmp[%i * 4 + 3] : memref<64xf32>
+ }
+
+ affine.for %i = 0 to 16 {
+ %vec = affine.vector_load %tmp[%i * 4] : memref<64xf32>, vector<4xf32>
+ affine.vector_store %vec, %dst[%i * 4] : memref<64xf32>, vector<4xf32>
+ }
+ memref.dealloc %tmp : memref<64xf32>
+ return
+}
+
+// -----
+
+// SIBLING-LABEL: func.func @sibling_vector_mismatch
+// SIBLING: affine.for %{{.*}} = 0 to 10 {
+// SIBLING: affine.vector_load %{{.*}} : memref<10x16xf32>, vector<4xf32>
+// SIBLING: affine.vector_load %{{.*}} : memref<10x16xf32>, vector<8xf32>
+// SIBLING: affine.vector_load %{{.*}} : memref<10x16xf32>, vector<4xf32>
+// SIBLING: }
+func.func @sibling_vector_mismatch(%src: memref<10x16xf32>) {
+ affine.for %i = 0 to 10 {
+ %vec = affine.vector_load %src[%i, 0] : memref<10x16xf32>, vector<4xf32>
+ }
+
+ affine.for %i = 0 to 10 {
+ %wide = affine.vector_load %src[%i, 8] : memref<10x16xf32>, vector<8xf32>
+ %vec = affine.vector_load %src[%i, 0] : memref<10x16xf32>, vector<4xf32>
+ }
+ return
+}