[llvm] [mlir] [mlir][Affine] Fix vector fusion legality and buffer sizing (PR #167229)

via llvm-commits llvm-commits at lists.llvm.org
Sat Dec 13 04:11:41 PST 2025


https://github.com/Men-cotton updated https://github.com/llvm/llvm-project/pull/167229

>From 17204e804dfa1bf82b6c3577d1bc6da193329982 Mon Sep 17 00:00:00 2001
From: mencotton <mencotton0410 at gmail.com>
Date: Mon, 3 Nov 2025 20:12:56 +0900
Subject: [PATCH 1/4] [mlir][Affine] Fix vector fusion legality and buffer
 sizing

---
 .../mlir/Dialect/Affine/Analysis/Utils.h      |  6 +-
 mlir/lib/Dialect/Affine/Analysis/Utils.cpp    | 11 ++-
 .../Dialect/Affine/Transforms/LoopFusion.cpp  | 25 ++++-
 .../Dialect/Affine/Utils/LoopFusionUtils.cpp  | 50 ++++++++++
 .../Dialect/Affine/loop-fusion-vector.mlir    | 97 +++++++++++++++++++
 5 files changed, 181 insertions(+), 8 deletions(-)
 create mode 100644 mlir/test/Dialect/Affine/loop-fusion-vector.mlir

diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
index df4145db90a61..9ee85e4b19308 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -547,10 +547,12 @@ struct MemRefRegion {
   /// use int64_t instead of uint64_t since index types can be at most
   /// int64_t. `lbs` are set to the lower bound maps for each of the rank
   /// dimensions where each of these maps is purely symbolic in the constraints
-  /// set's symbols.
+  /// set's symbols. If `minShape` is provided, each computed bound is at least
+  /// `minShape[d]` for dimension `d`.
   std::optional<int64_t> getConstantBoundingSizeAndShape(
       SmallVectorImpl<int64_t> *shape = nullptr,
-      SmallVectorImpl<AffineMap> *lbs = nullptr) const;
+      SmallVectorImpl<AffineMap> *lbs = nullptr,
+      ArrayRef<int64_t> minShape = {}) const;
 
   /// Gets the lower and upper bound map for the dimensional variable at
   /// `pos`.
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index f38493bc9a96e..4e934a3b6e580 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -25,6 +25,7 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/DebugLog.h"
 #include "llvm/Support/raw_ostream.h"
+#include <algorithm>
 #include <optional>
 
 #define DEBUG_TYPE "analysis-utils"
@@ -1158,10 +1159,12 @@ unsigned MemRefRegion::getRank() const {
 }
 
 std::optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
-    SmallVectorImpl<int64_t> *shape, SmallVectorImpl<AffineMap> *lbs) const {
+    SmallVectorImpl<int64_t> *shape, SmallVectorImpl<AffineMap> *lbs,
+    ArrayRef<int64_t> minShape) const {
   auto memRefType = cast<MemRefType>(memref.getType());
   MLIRContext *context = memref.getContext();
   unsigned rank = memRefType.getRank();
+  assert(minShape.empty() || minShape.size() == rank);
   if (shape)
     shape->reserve(rank);
 
@@ -1203,12 +1206,14 @@ std::optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
       lb = AffineMap::get(/*dimCount=*/0, cstWithShapeBounds.getNumSymbolVars(),
                           /*result=*/getAffineConstantExpr(0, context));
     }
-    numElements *= diffConstant;
+    int64_t finalDiff =
+        minShape.empty() ? diffConstant : std::max(diffConstant, minShape[d]);
+    numElements *= finalDiff;
     // Populate outputs if available.
     if (lbs)
       lbs->push_back(lb);
     if (shape)
-      shape->push_back(diffConstant);
+      shape->push_back(finalDiff);
   }
   return numElements;
 }
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index ff0157eb9e4f3..0fa140027b4c3 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -28,6 +28,7 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/DebugLog.h"
 #include "llvm/Support/raw_ostream.h"
+#include <algorithm>
 #include <iomanip>
 #include <optional>
 #include <sstream>
@@ -376,10 +377,28 @@ static Value createPrivateMemRef(AffineForOp forOp,
   SmallVector<int64_t, 4> newShape;
   SmallVector<AffineMap, 4> lbs;
   lbs.reserve(rank);
-  // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
-  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
+  SmallVector<int64_t, 4> minShape;
+  ArrayRef<int64_t> minShapeRef;
+  if (auto vectorStore = dyn_cast<AffineVectorStoreOp>(srcStoreOp)) {
+    ArrayRef<int64_t> vectorShape = vectorStore.getVectorType().getShape();
+    unsigned vectorRank = vectorShape.size();
+    if (vectorRank > rank) {
+      LDBG() << "Private memref creation unsupported for vector store with "
+             << "rank greater than memref rank";
+      return nullptr;
+    }
+    minShape.assign(rank, 0);
+    for (unsigned i = 0; i < vectorRank; ++i) {
+      unsigned memDim = rank - vectorRank + i;
+      int64_t vecDim = vectorShape[i];
+      assert(!ShapedType::isDynamic(vecDim) &&
+             "vector store should have static shape");
+      minShape[memDim] = std::max(minShape[memDim], vecDim);
+    }
+    minShapeRef = minShape;
+  }
   std::optional<int64_t> numElements =
-      region.getConstantBoundingSizeAndShape(&newShape, &lbs);
+      region.getConstantBoundingSizeAndShape(&newShape, &lbs, minShapeRef);
   assert(numElements && "non-constant number of elts in local buffer");
 
   const FlatAffineValueConstraints *cst = region.getConstraints();
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
index c6abb0d734d88..3963fab97749b 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -31,6 +31,22 @@
 using namespace mlir;
 using namespace mlir::affine;
 
+/// Returns the vector type associated with an affine vector load/store op.
+static std::optional<VectorType> getAffineVectorType(Operation *op) {
+  if (auto vectorLoad = dyn_cast<AffineVectorLoadOp>(op))
+    return vectorLoad.getVectorType();
+  if (auto vectorStore = dyn_cast<AffineVectorStoreOp>(op))
+    return vectorStore.getVectorType();
+  return std::nullopt;
+}
+
+/// Returns the memref underlying an affine read/write op.
+static Value getAccessMemRef(Operation *op) {
+  if (auto loadOp = dyn_cast<AffineReadOpInterface>(op))
+    return loadOp.getMemRef();
+  return cast<AffineWriteOpInterface>(op).getMemRef();
+}
+
 // Gathers all load and store memref accesses in 'opA' into 'values', where
 // 'values[memref] == true' for each store operation.
 static void getLoadAndStoreMemRefAccesses(Operation *opA,
@@ -334,6 +350,40 @@ FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp,
     break;
   }
 
+  // Guard vector fusion by matching producer/consumer vector shapes on actual
+  // dependence pairs (here we duplicate the early dependence check used in
+  // `computeSliceUnion` to avoid rejecting disjoint accesses).
+  for (Operation *srcOp : strategyOpsA) {
+    MemRefAccess srcAccess(srcOp);
+    auto srcVectorType = getAffineVectorType(srcOp);
+    bool srcIsRead = isa<AffineReadOpInterface>(srcOp);
+    for (Operation *dstOp : opsB) {
+      MemRefAccess dstAccess(dstOp);
+      if (srcAccess.memref != dstAccess.memref)
+        continue;
+      bool dstIsRead = isa<AffineReadOpInterface>(dstOp);
+      bool readReadAccesses = srcIsRead && dstIsRead;
+      DependenceResult result = checkMemrefAccessDependence(
+          srcAccess, dstAccess, /*loopDepth=*/numCommonLoops + 1,
+          /*dependenceConstraints=*/nullptr,
+          /*dependenceComponents=*/nullptr, readReadAccesses);
+      if (result.value == DependenceResult::Failure) {
+        LDBG() << "Dependency check failed";
+        return FusionResult::FailPrecondition;
+      }
+      if (result.value == DependenceResult::NoDependence)
+        continue;
+      if (readReadAccesses)
+        continue;
+      auto dstVectorType = getAffineVectorType(dstOp);
+      if (srcVectorType && dstVectorType &&
+          srcVectorType->getShape() != dstVectorType->getShape()) {
+        LDBG() << "Mismatching vector shapes between producer and consumer";
+        return FusionResult::FailPrecondition;
+      }
+    }
+  }
+
   // Compute union of computation slices computed between all pairs of ops
   // from 'forOpA' and 'forOpB'.
   SliceComputationResult sliceComputationResult = affine::computeSliceUnion(
diff --git a/mlir/test/Dialect/Affine/loop-fusion-vector.mlir b/mlir/test/Dialect/Affine/loop-fusion-vector.mlir
new file mode 100644
index 0000000000000..f5dd13c36f8d3
--- /dev/null
+++ b/mlir/test/Dialect/Affine/loop-fusion-vector.mlir
@@ -0,0 +1,97 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion))' -split-input-file | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING
+
+// CHECK-LABEL: func.func @skip_fusing_mismatched_vectors
+// CHECK: affine.for %{{.*}} = 0 to 8 {
+// CHECK:   affine.vector_store {{.*}} : memref<64x512xf32>, vector<64x64xf32>
+// CHECK: }
+// CHECK: affine.for %{{.*}} = 0 to 8 {
+// CHECK:   affine.vector_load {{.*}} : memref<64x512xf32>, vector<64x512xf32>
+// CHECK: }
+func.func @skip_fusing_mismatched_vectors(%a: memref<64x512xf32>, %b: memref<64x512xf32>, %c: memref<64x512xf32>, %d: memref<64x4096xf32>, %e: memref<64x4096xf32>) {
+  affine.for %j = 0 to 8 {
+    %lhs = affine.vector_load %a[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+    %rhs = affine.vector_load %b[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+    %res = arith.addf %lhs, %rhs : vector<64x64xf32>
+    affine.vector_store %res, %c[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+  }
+
+  affine.for %j = 0 to 8 {
+    %lhs = affine.vector_load %c[0, 0] : memref<64x512xf32>, vector<64x512xf32>
+    %rhs = affine.vector_load %d[0, %j * 512] : memref<64x4096xf32>, vector<64x512xf32>
+    %res = arith.subf %lhs, %rhs : vector<64x512xf32>
+    affine.vector_store %res, %d[0, %j * 512] : memref<64x4096xf32>, vector<64x512xf32>
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_private_memref
+// CHECK: memref.alloc() : memref<1x64xf32>
+// CHECK-NOT: memref<1x1xf32>
+// CHECK: affine.vector_store {{.*}} : memref<1x64xf32>, vector<64xf32>
+func.func @vector_private_memref(%src: memref<10x64xf32>, %dst: memref<10x64xf32>) {
+  %tmp = memref.alloc() : memref<10x64xf32>
+  affine.for %i = 0 to 10 {
+    %vec = affine.vector_load %src[%i, 0] : memref<10x64xf32>, vector<64xf32>
+    affine.vector_store %vec, %tmp[%i, 0] : memref<10x64xf32>, vector<64xf32>
+  }
+
+  affine.for %i = 0 to 10 {
+    %vec = affine.vector_load %tmp[%i, 0] : memref<10x64xf32>, vector<64xf32>
+    affine.vector_store %vec, %dst[%i, 0] : memref<10x64xf32>, vector<64xf32>
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fuse_scalar_vector
+// CHECK: %[[TMP:.*]] = memref.alloc() : memref<64xf32>
+// CHECK: affine.for %[[I:.*]] = 0 to 16 {
+// CHECK:   %[[S0:.*]] = affine.load %[[SRC:.*]][%[[I]] * 4] : memref<64xf32>
+// CHECK:   affine.store %[[S0]], %[[TMP]][%[[I]] * 4] : memref<64xf32>
+// CHECK:   %[[V:.*]] = affine.vector_load %[[TMP]][%[[I]] * 4] : memref<64xf32>, vector<4xf32>
+// CHECK:   affine.vector_store %[[V]], %[[DST:.*]][%[[I]] * 4] : memref<64xf32>, vector<4xf32>
+// CHECK: }
+func.func @fuse_scalar_vector(%src: memref<64xf32>, %dst: memref<64xf32>) {
+  %tmp = memref.alloc() : memref<64xf32>
+  affine.for %i = 0 to 16 {
+    %s0 = affine.load %src[%i * 4] : memref<64xf32>
+    affine.store %s0, %tmp[%i * 4] : memref<64xf32>
+    %s1 = affine.load %src[%i * 4 + 1] : memref<64xf32>
+    affine.store %s1, %tmp[%i * 4 + 1] : memref<64xf32>
+    %s2 = affine.load %src[%i * 4 + 2] : memref<64xf32>
+    affine.store %s2, %tmp[%i * 4 + 2] : memref<64xf32>
+    %s3 = affine.load %src[%i * 4 + 3] : memref<64xf32>
+    affine.store %s3, %tmp[%i * 4 + 3] : memref<64xf32>
+  }
+
+  affine.for %i = 0 to 16 {
+    %vec = affine.vector_load %tmp[%i * 4] : memref<64xf32>, vector<4xf32>
+    affine.vector_store %vec, %dst[%i * 4] : memref<64xf32>, vector<4xf32>
+  }
+  memref.dealloc %tmp : memref<64xf32>
+  return
+}
+
+// -----
+
+// SIBLING-LABEL: func.func @sibling_vector_mismatch
+// SIBLING: affine.for %{{.*}} = 0 to 10 {
+// SIBLING:   affine.vector_load %{{.*}} : memref<10x16xf32>, vector<4xf32>
+// SIBLING:   affine.vector_load %{{.*}} : memref<10x16xf32>, vector<8xf32>
+// SIBLING:   affine.vector_load %{{.*}} : memref<10x16xf32>, vector<4xf32>
+// SIBLING: }
+func.func @sibling_vector_mismatch(%src: memref<10x16xf32>) {
+  affine.for %i = 0 to 10 {
+    %vec = affine.vector_load %src[%i, 0] : memref<10x16xf32>, vector<4xf32>
+  }
+
+  affine.for %i = 0 to 10 {
+    %wide = affine.vector_load %src[%i, 8] : memref<10x16xf32>, vector<8xf32>
+    %vec = affine.vector_load %src[%i, 0] : memref<10x16xf32>, vector<4xf32>
+  }
+  return
+}

>From c38963745cd539f990d25f29a5f77fa229c28e7d Mon Sep 17 00:00:00 2001
From: mencotton <mencotton0410 at gmail.com>
Date: Mon, 10 Nov 2025 01:31:36 +0900
Subject: [PATCH 2/4] Fix: clang-format

---
 mlir/include/mlir/Dialect/Affine/Analysis/Utils.h | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
index 9ee85e4b19308..daf7976118cea 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -549,10 +549,10 @@ struct MemRefRegion {
   /// dimensions where each of these maps is purely symbolic in the constraints
   /// set's symbols. If `minShape` is provided, each computed bound is at least
   /// `minShape[d]` for dimension `d`.
-  std::optional<int64_t> getConstantBoundingSizeAndShape(
-      SmallVectorImpl<int64_t> *shape = nullptr,
-      SmallVectorImpl<AffineMap> *lbs = nullptr,
-      ArrayRef<int64_t> minShape = {}) const;
+  std::optional<int64_t>
+  getConstantBoundingSizeAndShape(SmallVectorImpl<int64_t> *shape = nullptr,
+                                  SmallVectorImpl<AffineMap> *lbs = nullptr,
+                                  ArrayRef<int64_t> minShape = {}) const;
 
   /// Gets the lower and upper bound map for the dimensional variable at
   /// `pos`.

>From 65f006e79c8439de95b7a37530b651395712c93b Mon Sep 17 00:00:00 2001
From: mencotton <mencotton0410 at gmail.com>
Date: Mon, 10 Nov 2025 01:51:30 +0900
Subject: [PATCH 3/4] Fix: remove unused function

---
 mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
index 3963fab97749b..48d1db15d84cb 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -40,13 +40,6 @@ static std::optional<VectorType> getAffineVectorType(Operation *op) {
   return std::nullopt;
 }
 
-/// Returns the memref underlying an affine read/write op.
-static Value getAccessMemRef(Operation *op) {
-  if (auto loadOp = dyn_cast<AffineReadOpInterface>(op))
-    return loadOp.getMemRef();
-  return cast<AffineWriteOpInterface>(op).getMemRef();
-}
-
 // Gathers all load and store memref accesses in 'opA' into 'values', where
 // 'values[memref] == true' for each store operation.
 static void getLoadAndStoreMemRefAccesses(Operation *opA,

>From ced8c31ad33c10ca2291bb45bb68809b8d087613 Mon Sep 17 00:00:00 2001
From: mencotton <mencotton0410 at gmail.com>
Date: Sat, 13 Dec 2025 21:11:07 +0900
Subject: [PATCH 4/4] Fix: refine and separate tests

---
 .agent/AGENTS.md                              | 123 +++
 .agent/TestingGuide.md                        | 831 ++++++++++++++++++
 .../Dialect/Affine/loop-fusion-sibling.mlir   |  23 +-
 .../Dialect/Affine/loop-fusion-vector.mlir    |  59 +-
 4 files changed, 999 insertions(+), 37 deletions(-)
 create mode 100644 .agent/AGENTS.md
 create mode 100644 .agent/TestingGuide.md

diff --git a/.agent/AGENTS.md b/.agent/AGENTS.md
new file mode 100644
index 0000000000000..cad18898c8896
--- /dev/null
+++ b/.agent/AGENTS.md
@@ -0,0 +1,123 @@
+# ExecPlans
+ 
+When writing complex features or significant refactors, use an ExecPlan (as described in .agent/PLANS.md) from design to implementation.
+
+# Useful Commands
+
+## Configure
+```
+cmake -S llvm -B build-format -G Ninja \
+    -DCLANG_INCLUDE_TESTS=ON \
+    -DCMAKE_BUILD_TYPE=RelWithDebInfo \
+    -DCMAKE_C_COMPILER_LAUNCHER=ccache \
+    -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
+    -DCMAKE_CXX_FLAGS="-Werror=unused-variable" \
+    -DLLVM_ENABLE_ASSERTIONS=ON \
+    -DLLVM_ENABLE_LLD=ON \
+    -DLLVM_ENABLE_PROJECTS="clang" \
+    -DLLVM_INCLUDE_TESTS=ON \
+    -DLLVM_PARALLEL_COMPILE_JOBS=10 \
+    -DLLVM_PARALLEL_LINK_JOBS=2 \
+    -DLLVM_TARGETS_TO_BUILD="Native" \
+```
+
+## Build
+`ninja -C build-format FormatTests`
+
+## Test
+- `./build-format/tools/clang/unittests/Format/FormatTests`
+- `./build-format/tools/clang/unittests/Format/FormatTests --gtest_filter=FormatTestComments.*`
+- `printf '/* comment\n */' | ./build-format/bin/clang-format -style='{SpaceInComments: {BeforeClosingComment: Never}}'`
+
+## Format document
+1. `cd clang/docs/tools && python3 dump_format_style.py`
+
+The following steps are optional.
+
+2. `cd ../../../ && mkdir -p html`
+3. `sphinx-build -n ./clang/docs ./html`
+
+# **RFC: Advanced Comment Spacing Control in clang-format**
+
+This document outlines a plan to replace the legacy `SpaceBeforeClosingBlockComment` boolean with a more powerful and flexible `SpaceInCommentsOptions` struct. The primary goal is to provide fine-grained control over whitespace within block comments and to integrate this logic cleanly into the reflow process, eliminating the need for special handling in the lexer.
+
+## **High-Level Plan**
+
+The proposal is to introduce a new `FormatStyle` field, `SpaceInComments`, controlled by a new `SpaceInCommentsOptions` struct. This will offer precise control over spacing at the start and end of different types of block comments. The entire implementation will be centered around the reflow logic, modifying the comment token's text directly at format time.
+
+*   **In Scope:**
+    *   A new `SpaceInCommentsOptions` struct (surfaced as `FormatStyle::SpaceInComments`) with four controls: `AfterOpeningComment`, `BeforeClosingComment`, `AfterOpeningParamComment`, and `BeforeClosingParamComment`. These will be controlled by a three-state `CommentSpaceMode` enum (`Leave`, `Always`, `Never`).
+    *   Centralized comment classification logic to distinguish between plain, parameter, and docstring comments so the spacing controls can specialize their behavior.
+    *   A clean integration into the reflow logic in `BreakableToken.cpp`, where the comment token's text will be directly rewritten.
+    *   Deprecation and removal of the legacy `SpaceBeforeClosingBlockComment` boolean.
+    *   Comprehensive documentation and testing for the new feature.
+*   **Out of Scope:**
+    *   Changes to star alignment in multi-line doc comments.
+    *   Spacing controls for line comments (`//`).
+
+---
+
+## **Implementation Plan**
+
+### **1. Data Model and Configuration**
+
+First, we will establish the data structures in `clang/include/clang/Format/Format.h`.
+
+*   **`enum class CommentSpaceMode`**: This will be the control enum with three states:
+    *   `Leave`: Preserves existing whitespace.
+    *   `Always`: Enforces a single space.
+    *   `Never`: Removes all horizontal whitespace.
+    This enum will be exposed to YAML configuration via `ScalarEnumerationTraits` in `Format.cpp`.
+
+*   **`struct SpaceInCommentsOptions`**: This struct will group the four independent controls and be added to `FormatStyle`.
+    ```c++
+    struct SpaceInCommentsOptions {
+      CommentSpaceMode AfterOpeningComment = CommentSpaceMode::Leave;
+      CommentSpaceMode BeforeClosingComment = CommentSpaceMode::Leave;
+      CommentSpaceMode AfterOpeningParamComment = CommentSpaceMode::Leave;
+      CommentSpaceMode BeforeClosingParamComment = CommentSpaceMode::Leave;
+    };
+
+    // In FormatStyle:
+    SpaceInCommentsOptions SpaceInComments;
+    ```
+    All preset styles will use the `Leave` default to ensure no behavior changes for existing users. `ClangFormatStyleOptions.rst` will be updated to document the new nested structure, and `dump_format_style.py` will be modified to emit it.
+
+### **2. Centralized Comment Classification**
+
+To apply different rules, we need to classify comments. This will be done in the lexer.
+
+*   In `FormatToken.h`, we will add `enum class CommentKind { Plain, DocString, Parameter };` and a corresponding member to `FormatToken` to store the classification for each block comment.
+*   In `FormatTokenLexer::getNextToken()`, a new helper function, `classifyBlockComment`, will be implemented. This function will determine the kind of each block comment:
+    *   **Docstrings**: Comments starting with `/**` or `/*!` will be classified as `CommentKind::DocString`. This ensures we can avoid modifying the indentation of Javadoc/Doxygen-style comments.
+    *   **Parameter comments**: A simple heuristic will be used: if the comment's content (after stripping `/*` and `*/`) ends with `=`, it will be classified as `CommentKind::Parameter`. This targets the common `/*name=*/` pattern.
+    *   **Plain comments**: All other block comments will be `CommentKind::Plain`.
+
+### **3. Reflow-Integrated Formatting Logic**
+
+This is the core of the implementation. All logic for adding or removing spaces inside comments will reside within the reflow process, specifically in `BreakableToken.cpp`. This avoids polluting the lexer with formatting state.
+
+*   **Core Principle**: Instead of tracking spacing information from the lexer, we will modify the `TokenText` of the comment token directly during the reflow. Since `FormatToken::TokenText` is a `StringRef` and cannot be modified, we will generate new text in an `ArenaAllocator` and update the `TokenText` to point to it.
+
+*   **Implementation Details**:
+    1.  Two new helper functions will be created in `BreakableToken.cpp`: `applyAfterOpeningBlockCommentSpacing` and `applyBeforeClosingBlockCommentSpacing`.
+    2.  These functions will be called from `BreakableBlockComment::adaptStartOfLine` and other relevant locations.
+    3.  The logic will operate on the full text of the comment token. It will split the comment into lines to intelligently handle single-line vs. multi-line cases.
+        *   For multi-line comments, spacing adjustments will only consider the content of the first and last lines.
+        *   **Leading whitespace**: Will only be added or removed if the **first line** of the block comment contains non-whitespace characters.
+        *   **Trailing whitespace**: Will only be added or removed if the **last line** of the block comment contains non-whitespace characters.
+    4.  A resolver function, `resolveCommentSpaceMode`, will select the appropriate `CommentSpaceMode` (`AfterOpeningComment` vs. `AfterOpeningParamComment`, etc.) based on the token's `CommentKind`. Docstrings will always be treated as `Leave` for the opening `/*` to protect leading star alignment.
+    5.  `WhitespaceManager::replaceWhitespaceInToken` will be used to apply the changes to the token.
+
+### **4. Deprecation and Finalization**
+
+Once the new logic is in place, we will remove the old `SpaceBeforeClosingBlockComment` boolean from `FormatStyle` and all associated logic. All documentation will be updated to reflect the new `SpaceInComments` option.
+
+### **5. Testing Strategy**
+
+The implementation will be validated by a multi-layered test suite to ensure correctness and prevent regressions.
+
+*   **Unit tests (`FormatTestComments.cpp`)**: Add extensive tests covering all combinations of the four new options for various comment types (single-line, multi-line, whitespace-only, parameter comments, docstrings). New tests like `InsertsSpaceAfterOpeningBlockComment`, `AfterOpeningParamCommentOverrides`, and `BeforeClosingParamCommentModes` will ensure each knob behaves independently and correctly interacts with the reflow logic.
+*   **Config parsing (`ConfigParseTest.cpp`)**: Ensure the nested YAML mapping for `SpaceInComments` can be correctly serialized and deserialized.
+*   **Lexing/annotation tests (`TokenAnnotatorTest.cpp`)**: Verify that the new `CommentKind` classification is accurate, especially for parameter comments near complex syntax.
+*   **Manual spot checks**: Use `clang-format` from the command line to sanity-check behavior. Before landing, run a full format sweep on `llvm-project` with both default (`Leave`) and forced (`Always`/`Never`) settings to audit for any unintended changes.
\ No newline at end of file
diff --git a/.agent/TestingGuide.md b/.agent/TestingGuide.md
new file mode 100644
index 0000000000000..38302ec2f51e8
--- /dev/null
+++ b/.agent/TestingGuide.md
@@ -0,0 +1,831 @@
+---
+title: "Testing Guide"
+date: 2019-11-29T15:26:15Z
+draft: false
+weight: 40
+---
+
+{{< toc >}}
+
+## Quickstart commands
+
+These commands are explained below in more detail. All commands are run from the
+cmake build directory `build-mlir/`, after [building the project](/getting_started/).
+
+### Run all MLIR tests:
+
+```sh
+cmake --build . --target check-mlir
+```
+
+### Run integration tests (requires `-DMLIR_INCLUDE_INTEGRATION_TESTS=ON`):
+
+```sh
+cmake --build . --target check-mlir-integration
+```
+
+### Run C++ unit tests:
+
+```sh
+bin/llvm-lit -v tools/mlir/test/Unit
+```
+
+### Run `lit` tests in a specific directory
+
+```sh
+bin/llvm-lit -v tools/mlir/test/Dialect/Arith
+```
+
+### Run a specific `lit` test file
+
+```sh
+bin/llvm-lit -v tools/mlir/test/Dialect/Polynomial/ops.mlir
+```
+
+## Test categories
+
+### `lit` and `FileCheck` tests
+
+[`FileCheck`](https://llvm.org/docs/CommandGuide/FileCheck.html) is a tool that
+"reads two files (one from standard input, and one specified on the command
+line) and uses one to verify the other." One file contains a set of `CHECK` tags
+that specify strings and patterns expected to appear in the other file. MLIR
+utilizes [`lit`](https://llvm.org/docs/CommandGuide/lit.html) to orchestrate the
+execution of tools like `mlir-opt` to produce an output, and `FileCheck` to
+verify different aspects of the IR—such as the output of a transformation pass.
+
+The source files of `lit`/`FileCheck` tests are organized within the `mlir`
+source tree under `mlir/test/`. Within this directory, tests are organized
+roughly mirroring `mlir/include/mlir/`, including subdirectories for `Dialect/`,
+`Transforms/`, `Conversion/`, etc.
+
+#### Example
+
+An example `FileCheck` test is shown below:
+
+```mlir
+// RUN: mlir-opt %s -cse | FileCheck %s
+
+// CHECK-LABEL: func.func @simple_constant
+func.func @simple_constant() -> (i32, i32) {
+  // CHECK-NEXT: %[[RESULT:.*]] = arith.constant 1
+  // CHECK-NEXT: return %[[RESULT]], %[[RESULT]]
+
+  %0 = arith.constant 1 : i32
+  %1 = arith.constant 1 : i32
+  return %0, %1 : i32, i32
+}
+```
+
+A comment with `RUN` represents a `lit` directive specifying a command line
+invocation to run, with special substitutions like `%s` for the current file. A
+comment with `CHECK` represents a `FileCheck` directive to assert a string or
+pattern appears in the output.
+
+The above test asserts that, after running Common Subexpression Elimination
+(`-cse`), only one constant remains in the IR, and the sole SSA value is
+returned twice from the function.
+
+#### Build system details
+
+The main way to run all the tests mentioned above in a single invocation is to
+use the `check-mlir` target:
+
+```sh
+cmake --build . --target check-mlir
+```
+
+Invoking the `check-mlir` target is roughly equivalent to running (from the
+build directory, after building):
+
+```shell
+./bin/llvm-lit tools/mlir/test
+```
+
+See the [Lit Documentation](https://llvm.org/docs/CommandGuide/lit.html) for a
+description of all options.
+
+Subsets of the testing tree can be invoked by passing a more specific path
+instead of `tools/mlir/test` above. Example:
+
+```shell
+./bin/llvm-lit tools/mlir/test/Dialect/Arith
+
+# Note that it is possible to test at the file granularity, but since these
+# files do not actually exist in the build directory, you need to know the
+# name.
+./bin/llvm-lit tools/mlir/test/Dialect/Arith/ops.mlir
+```
+
+Or for running all the C++ unit-tests:
+
+```shell
+./bin/llvm-lit tools/mlir/test/Unit
+```
+
+The C++ unit-tests can also be executed as individual binaries, which is
+convenient when iterating on cycles of rebuild-test:
+
+```shell
+# Rebuild the minimum amount of libraries needed for the C++ MLIRIRTests
+cmake --build . --target tools/mlir/unittests/IR/MLIRIRTests
+
+# Invoke the MLIRIRTest C++ Unit Test directly
+tools/mlir/unittests/IR/MLIRIRTests
+
+# It works for specific C++ unit-tests as well:
+LIT_OPTS="--filter=MLIRIRTests -a" cmake --build . --target check-mlir
+
+# Run just one specific subset inside the MLIRIRTests:
+tools/mlir/unittests/IR/MLIRIRTests --gtest_filter=OpPropertiesTest.Properties
+```
+
+Lit has a number of options that control test execution. Here are some of the
+most useful for development purposes:
+
+*   [`--filter=REGEXP`](https://llvm.org/docs/CommandGuide/lit.html#cmdoption-lit-filter) :
+    Only runs tests whose name matches the REGEXP. Can also be specified via the
+    `LIT_FILTER` environment variable.
+*   [`--filter-out=REGEXP`](https://llvm.org/docs/CommandGuide/lit.html#cmdoption-lit-filter-out) :
+    Filters out tests whose name matches the REGEXP. Can also be specified via
+    the `LIT_FILTER_OUT` environment variable.
+*   [`-a`](https://llvm.org/docs/CommandGuide/lit.html#cmdoption-lit-a) : Shows
+    all information (useful while iterating on a small set of tests).
+*   [`--time-tests`](https://llvm.org/docs/CommandGuide/lit.html#cmdoption-lit-time-tests) :
+    Prints timing statistics about slow tests and overall histograms.
+
+Any Lit options can be set in the `LIT_OPTS` environment variable. This is
+especially useful when using the build system target `check-mlir`.
+
+Examples:
+
+```
+# Only run tests that have "python" in the name and print all invocations.
+LIT_OPTS="--filter=python -a" cmake --build . --target check-mlir
+
+# Only run the array_attributes python test, using the LIT_FILTER mechanism.
+LIT_FILTER="python/ir/array_attributes" cmake --build . --target check-mlir
+
+# Run everything except for example and integration tests (which are both
+# somewhat slow).
+LIT_FILTER_OUT="Examples|Integrations" cmake --build . --target check-mlir
+```
+
+Note that the above use the generic cmake command for invoking the `check-mlir`
+target, but you can typically use the generator directly to be more concise
+(i.e. if configured for `ninja`, then `ninja check-mlir` can replace the `cmake
+--build . --target check-mlir` command). We use generic `cmake` commands in
+documentation for consistency, but being concise is often better for interactive
+workflows.
+
+### Diagnostic tests
+
+MLIR provides rich source location tracking that can be used to emit errors,
+warnings, etc. from anywhere throughout the codebase, which are jointly called
+*diagnostics*. Diagnostic tests assert that specific diagnostic messages are
+emitted for a given input program. These tests are useful in that they allow
+checking specific invariants of the IR without transforming or changing
+anything.
+
+Some examples of tests in this category are:
+
+-   Verifying invariants of operations
+-   Checking the expected results of an analysis
+-   Detecting malformed IR
+
+Diagnostic verification tests are written utilizing the
+[source manager verifier handler](../docs/Diagnostics#sourcemgr-diagnostic-verifier-handler),
+which is enabled via the `verify-diagnostics` flag in `mlir-opt`.
+
+An example .mlir test running under `mlir-opt` is shown below:
+
+```mlir
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// Expect an error on the same line.
+func.func @bad_branch() {
+  cf.br ^missing  // expected-error {{reference to an undefined block}}
+}
+
+// -----
+
+// Expect an error on an adjacent line.
+func.func @foo(%a : f32) {
+  // expected-error at +1 {{invalid predicate attribute specification: "foo"}}
+  %result = arith.cmpf "foo", %a, %a : f32
+  return
+}
+```
+
+### Integration tests
+
+Integration tests are `FileCheck` tests that verify functional correctness of
+MLIR code by running it, usually by means of JIT compilation using
+`mlir-cpu-runner` and runtime support libraries.
+
+Integration tests don't run by default. To enable them, set the
+`-DMLIR_INCLUDE_INTEGRATION_TESTS=ON` flag during `cmake` configuration as
+described in [Getting Started](_index.md).
+
+```sh
+cmake -G Ninja ../llvm \
+   ... \
+   -DMLIR_INCLUDE_INTEGRATION_TESTS=ON \
+   ...
+```
+
+Now the integration tests run as part of regular testing.
+
+```sh
+cmake --build . --target check-mlir
+```
+
+To run only the integration tests, run the `check-mlir-integration` target.
+
+```sh
+cmake --build . --target check-mlir-integration
+```
+
+Note that integration tests are relatively expensive to run (primarily due to
+JIT compilation), and tend to be trickier to debug (with multiple compilation
+steps _integrated_, it usually takes a bit of triaging to find the root cause
+of a failure). We reserve e2e tests for cases that are hard to verify
+otherwise, e.g. when composing and testing complex compilation pipelines. In
+those cases, verifying run-time output tends to be easier than checking
+e.g. LLVM IR with FileCheck. Lowering optimized `linalg.matmul` (with tiling
+and vectorization) is a good example. For less involved lowering pipelines or
+when there's an almost 1-1 mapping between an Op and its LLVM IR counterpart
+(e.g. `arith.cmpi` and the LLVM IR `icmp` instruction), regular unit tests are
+considered enough.
+
+The source files of the integration tests are organized within the `mlir` source
+tree by dialect (for example, `test/Integration/Dialect/Vector`).
+
+#### Hardware emulators
+
+The integration tests include some tests for targets that are not widely
+available yet, such as specific AVX512 features (like `vp2intersect`) and the
+Intel AMX instructions. These tests require an emulator to run correctly
+(lacking real hardware, of course). To enable these specific tests, first
+download and install the
+[Intel Emulator](https://software.intel.com/content/www/us/en/develop/articles/intel-software-development-emulator.html).
+Then, include the following additional configuration flags in the initial set up
+(X86Vector and AMX can be individually enabled or disabled), where `<path to
+emulator>` denotes the path to the installed emulator binary.
+
+```sh
+cmake -G Ninja ../llvm \
+   ... \
+   -DMLIR_INCLUDE_INTEGRATION_TESTS=ON \
+   -DMLIR_RUN_X86VECTOR_TESTS=ON \
+   -DMLIR_RUN_AMX_TESTS=ON \
+   -DINTEL_SDE_EXECUTABLE=<path to emulator> \
+   ...
+```
+
+After this one-time set up, the tests run as shown earlier, but will now include
+the indicated emulated tests as well.
+
+### C++ Unit tests
+
+Unit tests are written using the
+[googletest](https://google.github.io/googletest/) framework and are located in
+the `mlir/unittests/` directory.
+
+## Contributor guidelines
+
+In general, all commits to the MLIR repository should include an accompanying
+test of some form. Commits that include no functional changes, such as API
+changes like symbol renaming, should be tagged with NFC (No Functional Changes).
+This signals to the reviewer why the change doesn't/shouldn't include a test.
+
+`lit` tests with `FileCheck` are the preferred method of testing in MLIR for
+non-erroneous output verification.
+
+Diagnostic tests are the preferred method of asserting error messages are output
+correctly. Every user-facing error message (e.g., `op.emitError()`) should be
+accompanied by a corresponding diagnostic test.
+
+When you cannot use the above, such as for testing a non-user-facing API like a
+data structure, then you may write C++ unit tests. `lit` tests remain preferred
+wherever they are feasible because the C++ APIs are not stable and are subject
+to frequent refactoring; using `lit` and `FileCheck` allows maintainers to
+improve the MLIR internals more easily.
+
+### FileCheck best practices
+
+FileCheck is an extremely useful utility; it allows for easily matching various
+parts of the output. This ease of use means that it becomes easy to write
+brittle tests that are essentially `diff` tests. FileCheck tests should be as
+self-contained as possible and focus on testing the minimal set of
+functionalities needed. Let's see an example:
+
+```mlir
+// RUN: mlir-opt %s -cse | FileCheck %s
+
+// CHECK-LABEL: func.func @simple_constant() -> (i32, i32)
+func.func @simple_constant() -> (i32, i32) {
+  // CHECK-NEXT: %result = arith.constant 1 : i32
+  // CHECK-NEXT: return %result, %result : i32, i32
+  // CHECK-NEXT: }
+
+  %0 = arith.constant 1 : i32
+  %1 = arith.constant 1 : i32
+  return %0, %1 : i32, i32
+}
+```
+
+The above example is another way to write the original example shown in the main
+[`lit` and `FileCheck` tests](#lit-and-filecheck-tests) section. There are a few
+problems with this test; below is a breakdown of the no-nos of this test to
+specifically highlight best practices.
+
+*   Tests should be self-contained.
+
+This means that tests should not test lines or sections outside of what is
+intended. In the above example, we see lines such as `CHECK-NEXT: }`. This line
+in particular is testing pieces of the Parser/Printer of FuncOp, which is
+outside of the realm of concern for the CSE pass. This line should be removed.
+
+*   Tests should be minimal, and only check what is absolutely necessary.
+
+This means that anything in the output that is not core to the functionality
+that you are testing should *not* be present in a CHECK line. This is a separate
+bullet just to highlight the importance of it, especially when checking against
+IR output.
+
+If we naively remove the unrelated `CHECK` lines in our source file, we may end
+up with:
+
+```mlir
+// CHECK-LABEL: func.func @simple_constant
+func.func @simple_constant() -> (i32, i32) {
+  // CHECK-NEXT: %result = arith.constant 1 : i32
+  // CHECK-NEXT: return %result, %result : i32, i32
+
+  %0 = arith.constant 1 : i32
+  %1 = arith.constant 1 : i32
+  return %0, %1 : i32, i32
+}
+```
+
+It may seem like this is a minimal test case, but it still checks several
+aspects of the output that are unrelated to the CSE transformation. Namely the
+result types of the `arith.constant` and `return` operations, as well as the
+actual SSA value names that are produced. FileCheck `CHECK` lines may contain
+[regex statements](https://llvm.org/docs/CommandGuide/FileCheck.html#filecheck-regex-matching-syntax)
+as well as named
+[string substitution blocks](https://llvm.org/docs/CommandGuide/FileCheck.html#filecheck-string-substitution-blocks).
+Utilizing the above, we end up with the example shown in the main
+[FileCheck tests](#filecheck-tests) section.
+
+```mlir
+// CHECK-LABEL: func.func @simple_constant
+func.func @simple_constant() -> (i32, i32) {
+  /// Here we use a substitution variable as the output of the constant is
+  /// useful for the test, but we omit as much as possible of everything else.
+  // CHECK-NEXT: %[[RESULT:.*]] = arith.constant 1
+  // CHECK-NEXT: return %[[RESULT]], %[[RESULT]]
+
+  %0 = arith.constant 1 : i32
+  %1 = arith.constant 1 : i32
+  return %0, %1 : i32, i32
+}
+```
+
+### Test Formatting Best Practices
+
+When adding new tests, strive to follow these two key rules:
+
+1. **Follow the existing naming and whitespace style.**
+   - This applies when modifying existing test files that follow a particular
+     convention, as it likely fits the context.
+2. **Consistently document the edge case being tested.**
+   - Clearly state what makes this test unique and how it complements other
+     similar tests.
+
+While the first rule extends LLVM’s general coding style to tests, the second
+may feel new. The goal is to improve:
+
+- **Test discoverability** – Well-documented tests make it easier to pair tests
+  with patterns and understand their purpose.
+- **Test consistency** – Consistent documentation and naming lowers cognitive
+  load and helps avoid duplication.
+
+A well-thought-out naming convention helps achieve all of the above.
+
+---
+
+#### Example: Improving Test Readability & Naming
+
+Consider these **three tests** that exercise `vector.maskedload -> vector.load`
+lowering under the `-test-vector-to-vector-lowering` flag:
+
+##### Before: Inconsistent & Hard to Differentiate
+
+```mlir
+// CHECK-LABEL:   func @maskedload_regression_1(
+//  CHECK-SAME:       %[[A0:.*]]: memref<16xf32>,
+//  CHECK-SAME:       %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
+//       CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:   %[[LOAD:.*]] = vector.load %[[A0]][%[[C]]]
+//  CHECK-SAME:     : memref<16xf32>, vector<16xf32>
+//       CHECK:   return %[[LOAD]] : vector<16xf32>
+func.func @maskedload_regression_1(
+    %arg0: memref<16xf32>,
+    %arg1: vector<16xf32>) -> vector<16xf32> {
+  %c0 = arith.constant 0 : index
+
+  %vec_i1 = vector.constant_mask [16] : vector<16xi1>
+  %ld = vector.maskedload %arg0[%c0], %vec_i1, %arg1
+    : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+
+  return %ld : vector<16xf32>
+}
+
+// CHECK-LABEL:   func @maskedload_regression_2(
+//  CHECK-SAME:       %[[A0:.*]]: memref<16xi8>,
+//  CHECK-SAME:       %[[A1:.*]]: vector<16xi8>) -> vector<16xi8> {
+//       CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:   %[[LOAD:.*]] = vector.load %[[A0]][%[[C]]]
+//  CHECK-SAME:     : memref<16xi8>, vector<16xi8>
+//       CHECK:   return %[[LOAD]] : vector<16xi8>
+func.func @maskedload_regression_2(
+    %arg0: memref<16xi8>,
+    %arg1: vector<16xi8>) -> vector<16xi8> {
+  %c0 = arith.constant 0 : index
+
+  %vec_i1 = vector.constant_mask [16] : vector<16xi1>
+  %ld = vector.maskedload %arg0[%c0], %vec_i1, %arg1
+    : memref<16xi8>, vector<16xi1>, vector<16xi8> into vector<16xi8>
+
+  return %ld : vector<16xi8>
+}
+
+// CHECK-LABEL:   func @maskedload_regression_3(
+// CHECK-SAME:        %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME:        %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
+//      CHECK:    return %[[A1]] : vector<16xf32>
+func.func @maskedload_regression_3(
+    %arg0: memref<16xf32>,
+    %arg1: vector<16xf32>) -> vector<16xf32> {
+  %c0 = arith.constant 0 : index
+
+  %vec_i1 = vector.constant_mask [0] : vector<16xi1>
+  %ld = vector.maskedload %arg0[%c0], %vec_i1, %arg1
+    : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+
+  return %ld : vector<16xf32>
+}
+```
+
+While all examples test `vector.maskedload` -> `vector.load` lowering, it is
+difficult to tell their actual differences.
+
+##### After Step 1 (Introduce Consistent Variable Names)
+
+To reduce cognitive load, use consistent names across MLIR and FileCheck (e.g.,
+`%arg0` and `A0` above are not consistent). Also, instead of using generic
+names like `%arg0` or `%vec_i1`, encode some additional context by using names
+from existing documentation. For example from the Op documentation,
+[`vector.maskedload`](https://mlir.llvm.org/docs/Dialects/Vector/#vectormaskedload-vectormaskedloadop),
+in this case, you can use `%base`, `%mask` and `%pass_thru`.
+
+```mlir
+// CHECK-LABEL:   func @maskedload_regression_1(
+//  CHECK-SAME:       %[[BASE:.*]]: memref<16xf32>,
+//  CHECK-SAME:       %[[PASS_THRU:.*]]: vector<16xf32>) -> vector<16xf32> {
+// (...)
+func.func @maskedload_regression_1(
+    %base: memref<16xf32>,
+    %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  // (...)
+  %mask = vector.constant_mask [16] : vector<16xi1>
+  %ld = vector.maskedload %base[%c0], %mask, %pass_thru (...)
+  // (...)
+}
+
+// CHECK-LABEL:   func @maskedload_regression_2(
+//  CHECK-SAME:       %[[BASE:.*]]: memref<16xi8>,
+//  CHECK-SAME:       %[[PASS_THRU:.*]]: vector<16xi8>) -> vector<16xi8> {
+// (...)
+func.func @maskedload_regression_2(
+    %base: memref<16xi8>,
+    %pass_thru: vector<16xi8>) -> vector<16xi8> {
+  // (...)
+  %mask = vector.constant_mask [16] : vector<16xi1>
+  %ld = vector.maskedload %base[%c0], %mask, %pass_thru (...)
+  // (...)
+}
+
+// CHECK-LABEL:   func @maskedload_regression_3(
+//  CHECK-SAME:       %[[BASE:.*]]: memref<16xf32>,
+//  CHECK-SAME:       %[[PASS_THRU:.*]]: vector<16xf32>) -> vector<16xf32> {
+// (...)
+func.func @maskedload_regression_3(
+    %base: memref<16xf32>,
+    %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  // (...)
+  %mask = vector.constant_mask [0] : vector<16xi1>
+  %ld = vector.maskedload %base[%c0], %mask, %pass_thru (...)
+  // (...)
+}
+```
+
+##### After Step 2 (Improve Test Naming)
+
+Instead of using "regression" (which does not add unique information), rename
+tests based on key attributes:
+
+* All examples test the `vector.maskedload` to `vector.load` lowering.
+* The first test uses a _dynamically_ shaped `memref`, while the others use
+  _static_ shapes.
+* The mask in the first two examples is "all true" (`vector.constant_mask
+  [16]`), while it is "all false" (`vector.constant_mask [0]`) in the third
+  example.
+* The first and the third tests use `i32` elements, whereas the second uses
+  `i8`.
+
+This suggests the following naming scheme:
+* `@maskedload_to_load_{static|dynamic}_{i32|i8}_{all_true|all_false}`.
+
+Below are the updated names:
+
+```mlir
+// CHECK-LABEL:   func @maskedload_to_load_dynamic_i32_all_true(
+// (...)
+func.func @maskedload_to_load_dynamic_i32_all_true(...) -> vector<16xf32> {
+  // (...)
+}
+
+// CHECK-LABEL:   func @maskedload_to_load_static_i8_all_true(
+// (...)
+func.func @maskedload_to_load_static_i8_all_true(...) -> vector<16xi8> {
+  // (...)
+}
+
+// CHECK-LABEL:   func @maskedload_to_load_static_i32_all_false(
+// (...)
+func.func @maskedload_to_load_static_i32_all_false(...) -> vector<16xf32> {
+  // (...)
+}
+```
+
+##### After Step 3 (Add The Newly Identified Missing Case)
+
+Step 2 made it possible to see that there is a case which is not tested:
+
+* A mask that is neither "all true" nor "all false".
+
+Unlike the existing cases, this mask must be preserved. In this scenario,
+`vector.load` is not the right abstraction. Thus, no lowering should occur:
+
+```mlir
+// CHECK-LABEL:   func @negative_maskedload_to_load_static_i32_mixed(
+// CHECK-SAME:        %[[BASE:.*]]: memref<16xf32>,
+// CHECK-SAME:        %[[PASS_THRU:.*]]: vector<16xf32>) -> vector<16xf32> {
+//      CHECK:    vector.maskedload
+func.func @negative_maskedload_to_load_static_i32_mixed(
+    %base: memref<16xf32>,
+    %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %c0 = arith.constant 0 : index
+  %mask = vector.constant_mask [4] : vector<16xi1>
+
+  %ld = vector.maskedload %base[%c0], %mask, %pass_thru
+    : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+
+  return %ld : vector<16xf32>
+}
+```
+
+The `negative_` prefix indicates that this test should fail to lower, as the
+pattern should not match.
+
+##### Test Naming Convention
+To summarize, here is the naming convention used in the examples above:
+
+* `@{negative_}?maskedload_to_load_{static|dynamic}_{i32|i8}_{all_true|all_false|mixed}`.
+
+The exact format may vary depending on context. However:
+* **Avoid using suffixes** (e.g., `_fail`) to indicate negative tests — prefixes like
+  `negative_` are easier to spot and grep for.
+* Whatever naming convention you choose, **apply it consistently** throughout
+  the test suite.
+
+**Note:** In some cases, a prefix other than `negative_` might be more
+appropriate. For instance, in "folding" tests where a pattern is expected not
+to apply, using `no_` can be a more concise and equally clear alternative —
+e.g., `@no_fold_<case>_<subcase>.`
+
+#### What if there is no pre-existing style to follow?
+
+If you are adding a new test file, you can use other test files in the same
+directory as inspiration.
+
+If the test file you are modifying lacks a clear style and instead has mixed,
+inconsistent styles, try to identify the dominant one and follow it. Even
+better, consider refactoring the file to adopt a single, consistent style —
+this helps improve our overall testing quality. Refactoring is also encouraged
+when the existing style could be improved.
+
+In many cases, it is best to create a separate PR for test refactoring to
+reduce per-PR noise. However, this depends on the scale of changes — reducing
+PR traffic is also important. Work with reviewers to use your judgment and
+decide the best approach.
+
+Alternatively, if you defer refactoring, consider creating a GitHub issue and
+adding a TODO in the test file linking to it.
+
+When creating a new naming convention, keep these points in mind:
+
+* **Write Orthogonal Tests**
+If naming is difficult then the tests may be lacking a clear purpose. A good
+rule of thumb is to avoid testing the same thing repeatedly. Before writing
+tests, define clear categories to cover (e.g., number of loops, data types).
+This often leads to a natural naming scheme—for example: `@loop_depth_2_i32`.
+
+* **What vs Why**
+Test names should reflect _what_ is being tested, not _why_.
+
+Encoding _why_ in test names can lead to overly long and complex names.
+Instead, add inline comments where needed.
+
+#### Do not forget the common sense
+
+Always apply common sense when naming functions and variables. Encoding too
+much information in names makes the tests less readable and less maintainable.
+
+Trust your judgment. When in doubt, consult your "future self": _"Will this still
+make sense to me six months from now?"_
+
+#### Final Points - Key Principles
+
+The above approach is just an example. It may not fit your use case perfectly,
+so feel free to adapt it as needed. Key principles to follow:
+
+* Make tests self-documenting.
+* Follow existing conventions.
+
+These principles make tests easier to discover and maintain — for you, "future
+you", and the rest of the MLIR community.
+
+### Test Documentation Best Practices
+
+In addition to following good naming and formatting conventions, please
+document your tests with comments. Focus on explaining **why** since the
+**what** is usually clear from the code itself.
+
+As an example, consider this test that uses the
+`TransferWritePermutationLowering` pattern:
+
+
+```mlir
+/// Even with out-of-bounds accesses, it is safe to apply this pattern as it
+/// does not modify which memory location is being accessed.
+
+// CHECK-LABEL:   func.func @xfer_write_minor_identity_transposed_out_of_bounds
+//  CHECK-SAME:      %[[VEC:.*]]: vector<4x8xi16>
+//  CHECK-SAME:      %[[MEM:.*]]: memref<2x2x?x?xi16>
+//  CHECK-SAME:      %[[IDX:.*]]: index)
+//       CHECK:      %[[TR:.*]] = vector.transpose %[[VEC]], [1, 0]
+//  CHECK-SAME:        : vector<4x8xi16> to vector<8x4xi16>
+
+/// Expect the in_bounds attribute to be preserved. However, since we don't
+/// print it when all flags are "false", it should not appear in the output.
+/// CHECK-NOT:       in_bounds
+
+// CHECK:           vector.transfer_write
+
+/// The permutation map was replaced with vector.transpose
+// CHECK-NOT:       permutation_map
+
+// CHECK-SAME:        %[[TR]], %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]
+// CHECK-SAME:        : vector<8x4xi16>, memref<2x2x?x?xi16>
+func.func @xfer_write_minor_identity_transposed_out_of_bounds(
+    %vec: vector<4x8xi16>,
+    %mem: memref<2x2x?x?xi16>,
+    %idx: index) {
+
+  vector.transfer_write %vec, %mem[%idx, %idx, %idx, %idx] {
+    in_bounds = [false, false],
+    permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+  } : vector<4x8xi16>, memref<2x2x?x?xi16>
+
+  return
+}
+```
+
+The comments in the example above document two non-obvious behaviors:
+
+* _Why_ is the `permutation_map` attribute missing from the output?
+* _Why_ is the `in_bounds` attribute missing from the output?
+
+
+#### How to Identify What Needs Documentation?
+Think of yourself six months from now and ask: _"What might be difficult to
+understand without comments?"_
+
+If you expect something to be tricky for "future-you", it’s likely to be tricky
+for others encountering the test for the first time.
+
+#### Making Tests Self-Documenting
+We can improve documentation further by:
+* clarifying what pattern is being tested,
+* providing high-level reasoning, and
+* consolidating shared comments.
+
+For example:
+
+```mlir
+///--------------------------------------------------------------------------------
+/// [Pattern: TransferWritePermutationLowering]
+///
+/// IN: vector.transfer_write (_transposed_ minor identity permutation map)
+/// OUT: vector.transpose + vector.transfer_write (minor identity permutation map)
+///
+/// Note: `permutation_map` from the input Op is replaced with the newly
+/// inserted vector.transpose Op.
+///--------------------------------------------------------------------------------
+// CHECK-LABEL:   func.func @xfer_write_minor_identity_transposed
+//       (...)
+//       CHECK:      %[[TR:.*]] = vector.transpose (...)
+//       CHECK:      vector.transfer_write %[[TR]] (...)
+//       (...)
+```
+
+The example above documents:
+* The transformation pattern being tested.
+* The key logic behind the transformation.
+* The expected change in output.
+
+
+#### Documenting the "What"
+You should always document why, but documenting what is also valid and
+encouraged in cases where:
+
+* The test output is long and complex.
+* The tested logic is non-trivial and/or involves multiple transformations.
+
+For example, in this test for Linalg convolution vectorization, comments are
+used to document high-level steps (original FileCheck "check" lines have been
+trimmed for brevity):
+
+```mlir
+func.func @conv1d_nwc_4x2x8_memref(
+    %input: memref<4x6x3xf32>,
+    %filter: memref<1x3x8xf32>,
+    %output: memref<4x2x8xf32>) {
+  linalg.conv_1d_nwc_wcf
+    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x6x3xf32>, memref<1x3x8xf32>)
+    outs(%output : memref<4x2x8xf32>)
+  return
+}
+
+//      CHECK: func @conv1d_nwc_4x2x8_memref
+// CHECK-SAME: %[[INPUT:.+]]: memref<4x6x3xf32>
+// CHECK-SAME: %[[FILTER:.+]]: memref<1x3x8xf32>,
+// CHECK-SAME: %[[OUTPUT:.+]]: memref<4x2x8xf32>
+
+/// Read the whole data in one shot.
+//  CHECK-DAG:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]]
+//  CHECK-DAG:  %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]]
+//  CHECK-DAG:  %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]]
+
+//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+
+//      CHECK:    %[[V_FILTER:.+]] = vector.extract %[[V_FILTER_R]][0]
+
+//      CHECK:  %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+//      CHECK:  %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+
+/// w == 0, kw == 0
+//      CHECK:   %[[CONTRACT_0:.+]] = vector.contract
+// CHECK-SAME:     %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
+
+/// w == 1, kw == 0
+//      CHECK:   %[[CONTRACT_1:.+]] = vector.contract
+// CHECK-SAME:     %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]]
+
+/// w == 0, kw == 0
+//      CHECK:   %[[RES_0:.+]] = vector.insert_strided_slice
+// CHECK-SAME:    %[[CONTRACT_0]], %[[V_OUTPUT_R]]
+/// w == 1, kw == 0
+//      CHECK:   %[[RES_1:.+]] = vector.insert_strided_slice
+// CHECK-SAME:    %[[CONTRACT_1]], %[[RES_0]]
+
+/// Write the result back in one shot.
+//      CHECK:   vector.transfer_write %[[RES_1]], %[[OUTPUT]]
+```
+
+Though the comments document _what_ is happening (e.g., "Write the result back
+in one shot"), some variables — like `w` and `kw` — are not explained. This is
+intentional - their purpose becomes clear when studying the corresponding
+Linalg vectorizer implementation (or, when analysing how
+`linalg.conv_1d_nwc_wcf` works).
+
+Comments help you understand code; they do not replace the need to read it.
+Comments guide the reader; they do not repeat what the code already says.
+
+#### Final Points - Key Principles
+Below are key principles to follow when documenting tests:
+* Always document _why_, document _what_ if you need to (e.g. the underlying
+	logic is non-trivial).
+* Use block comments for higher-level comments (e.g. to describe the patterns
+	being tested).
+* Think about maintainability - comments should help future developers (which
+	includes you) understand tests at a glance.
+* Avoid over-explaining. Comments should assist, not replace reading the code.
+* Avoid relative time references. Terms like "previously", "currently", or "newly" become outdated quickly. Focus on describing the invariant or the behavior being tested, rather than the history of the bug.
diff --git a/mlir/test/Dialect/Affine/loop-fusion-sibling.mlir b/mlir/test/Dialect/Affine/loop-fusion-sibling.mlir
index 937c855b86b50..88ea334bb9349 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-sibling.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-sibling.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{maximal mode=sibling}))' | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{maximal mode=sibling}))' -split-input-file | FileCheck %s
 
 // Test cases specifically for sibling fusion. Note that sibling fusion test
 // cases also exist in loop-fusion*.mlir.
@@ -21,3 +21,24 @@ func.func @disjoint_stores(%0: memref<8xf32>) {
   // CHECK-NOT: affine.for
   return
 }
+
+// -----
+
+// CHECK-LABEL: func.func @sibling_fusion_shape_mismatch
+// CHECK: affine.for %{{.*}} = 0 to 10 {
+// CHECK:   affine.vector_load %{{.*}} : memref<10x16xf32>, vector<4xf32>
+// CHECK:   affine.vector_load %{{.*}} : memref<10x16xf32>, vector<8xf32>
+// CHECK:   affine.vector_load %{{.*}} : memref<10x16xf32>, vector<4xf32>
+
+/// Read-After-Read dependence does not require vector shape alignment.
+func.func @sibling_fusion_shape_mismatch(%src: memref<10x16xf32>) {
+  affine.for %i = 0 to 10 {
+    %vec = affine.vector_load %src[%i, 0] : memref<10x16xf32>, vector<4xf32>
+  }
+
+  affine.for %i = 0 to 10 {
+    %wide = affine.vector_load %src[%i, 8] : memref<10x16xf32>, vector<8xf32>
+    %vec = affine.vector_load %src[%i, 0] : memref<10x16xf32>, vector<4xf32>
+  }
+  return
+}
diff --git a/mlir/test/Dialect/Affine/loop-fusion-vector.mlir b/mlir/test/Dialect/Affine/loop-fusion-vector.mlir
index f5dd13c36f8d3..6ea9d1302ed06 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-vector.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-vector.mlir
@@ -1,37 +1,43 @@
 // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion))' -split-input-file | FileCheck %s
-// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING
 
-// CHECK-LABEL: func.func @skip_fusing_mismatched_vectors
+// CHECK-LABEL: func.func @negative_fusion_producer_consumer_shape_mismatch
 // CHECK: affine.for %{{.*}} = 0 to 8 {
 // CHECK:   affine.vector_store {{.*}} : memref<64x512xf32>, vector<64x64xf32>
-// CHECK: }
 // CHECK: affine.for %{{.*}} = 0 to 8 {
 // CHECK:   affine.vector_load {{.*}} : memref<64x512xf32>, vector<64x512xf32>
-// CHECK: }
-func.func @skip_fusing_mismatched_vectors(%a: memref<64x512xf32>, %b: memref<64x512xf32>, %c: memref<64x512xf32>, %d: memref<64x4096xf32>, %e: memref<64x4096xf32>) {
+
+/// Mismatched vector shapes prevent valid fusion due to element misalignment.
+func.func @negative_fusion_producer_consumer_shape_mismatch(
+    %arg0: memref<64x512xf32>,
+    %arg1: memref<64x512xf32>,
+    %arg2: memref<64x512xf32>,
+    %arg3: memref<64x4096xf32>) {
   affine.for %j = 0 to 8 {
-    %lhs = affine.vector_load %a[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
-    %rhs = affine.vector_load %b[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+    %lhs = affine.vector_load %arg0[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+    %rhs = affine.vector_load %arg1[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
     %res = arith.addf %lhs, %rhs : vector<64x64xf32>
-    affine.vector_store %res, %c[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+    affine.vector_store %res, %arg2[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
   }
 
   affine.for %j = 0 to 8 {
-    %lhs = affine.vector_load %c[0, 0] : memref<64x512xf32>, vector<64x512xf32>
-    %rhs = affine.vector_load %d[0, %j * 512] : memref<64x4096xf32>, vector<64x512xf32>
+    %lhs = affine.vector_load %arg2[0, 0] : memref<64x512xf32>, vector<64x512xf32>
+    %rhs = affine.vector_load %arg3[0, %j * 512] : memref<64x4096xf32>, vector<64x512xf32>
     %res = arith.subf %lhs, %rhs : vector<64x512xf32>
-    affine.vector_store %res, %d[0, %j * 512] : memref<64x4096xf32>, vector<64x512xf32>
+    affine.vector_store %res, %arg3[0, %j * 512] : memref<64x4096xf32>, vector<64x512xf32>
   }
   return
 }
 
 // -----
 
-// CHECK-LABEL: func.func @vector_private_memref
+// CHECK-LABEL: func.func @fusion_private_memref_vector_size
 // CHECK: memref.alloc() : memref<1x64xf32>
 // CHECK-NOT: memref<1x1xf32>
 // CHECK: affine.vector_store {{.*}} : memref<1x64xf32>, vector<64xf32>
-func.func @vector_private_memref(%src: memref<10x64xf32>, %dst: memref<10x64xf32>) {
+
+/// Private buffer must accommodate vector shape (1x64), not just scalar shape
+/// (1x1).
+func.func @fusion_private_memref_vector_size(%src: memref<10x64xf32>, %dst: memref<10x64xf32>) {
   %tmp = memref.alloc() : memref<10x64xf32>
   affine.for %i = 0 to 10 {
     %vec = affine.vector_load %src[%i, 0] : memref<10x64xf32>, vector<64xf32>
@@ -47,15 +53,16 @@ func.func @vector_private_memref(%src: memref<10x64xf32>, %dst: memref<10x64xf32
 
 // -----
 
-// CHECK-LABEL: func.func @fuse_scalar_vector
+// CHECK-LABEL: func.func @fusion_scalar_producer_vector_consumer
 // CHECK: %[[TMP:.*]] = memref.alloc() : memref<64xf32>
 // CHECK: affine.for %[[I:.*]] = 0 to 16 {
 // CHECK:   %[[S0:.*]] = affine.load %[[SRC:.*]][%[[I]] * 4] : memref<64xf32>
 // CHECK:   affine.store %[[S0]], %[[TMP]][%[[I]] * 4] : memref<64xf32>
 // CHECK:   %[[V:.*]] = affine.vector_load %[[TMP]][%[[I]] * 4] : memref<64xf32>, vector<4xf32>
 // CHECK:   affine.vector_store %[[V]], %[[DST:.*]][%[[I]] * 4] : memref<64xf32>, vector<4xf32>
-// CHECK: }
-func.func @fuse_scalar_vector(%src: memref<64xf32>, %dst: memref<64xf32>) {
+
+/// Scalar-to-vector fusion requires correct intermediate buffer alloc.
+func.func @fusion_scalar_producer_vector_consumer(%src: memref<64xf32>, %dst: memref<64xf32>) {
   %tmp = memref.alloc() : memref<64xf32>
   affine.for %i = 0 to 16 {
     %s0 = affine.load %src[%i * 4] : memref<64xf32>
@@ -75,23 +82,3 @@ func.func @fuse_scalar_vector(%src: memref<64xf32>, %dst: memref<64xf32>) {
   memref.dealloc %tmp : memref<64xf32>
   return
 }
-
-// -----
-
-// SIBLING-LABEL: func.func @sibling_vector_mismatch
-// SIBLING: affine.for %{{.*}} = 0 to 10 {
-// SIBLING:   affine.vector_load %{{.*}} : memref<10x16xf32>, vector<4xf32>
-// SIBLING:   affine.vector_load %{{.*}} : memref<10x16xf32>, vector<8xf32>
-// SIBLING:   affine.vector_load %{{.*}} : memref<10x16xf32>, vector<4xf32>
-// SIBLING: }
-func.func @sibling_vector_mismatch(%src: memref<10x16xf32>) {
-  affine.for %i = 0 to 10 {
-    %vec = affine.vector_load %src[%i, 0] : memref<10x16xf32>, vector<4xf32>
-  }
-
-  affine.for %i = 0 to 10 {
-    %wide = affine.vector_load %src[%i, 8] : memref<10x16xf32>, vector<8xf32>
-    %vec = affine.vector_load %src[%i, 0] : memref<10x16xf32>, vector<4xf32>
-  }
-  return
-}



More information about the llvm-commits mailing list