[Mlir-commits] [mlir] [mlir][vector] MLIR SLP vectorizer (PR #140469)

Ivan Butygin llvmlistbot at llvm.org
Sun Jun 1 03:59:33 PDT 2025


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/140469

>From aa11ef8eda1f8392b32b453f96981fff66514ca3 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 17 May 2025 23:01:41 +0200
Subject: [PATCH 01/52] stubs

---
 .../mlir/Dialect/Vector/Transforms/Passes.h   |  3 +
 .../mlir/Dialect/Vector/Transforms/Passes.td  | 12 ++++
 .../Dialect/Vector/Transforms/CMakeLists.txt  |  1 +
 .../Vector/Transforms/SLPVectorizer.cpp       | 63 +++++++++++++++++++
 mlir/test/Dialect/Vector/slp-vectorize.mlir   | 34 ++++++++++
 5 files changed, 113 insertions(+)
 create mode 100644 mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
 create mode 100644 mlir/test/Dialect/Vector/slp-vectorize.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index 5667f4fa95ace..43112f084dc60 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -25,6 +25,9 @@ std::unique_ptr<Pass> createLowerVectorMultiReductionPass(
     VectorMultiReductionLowering option =
         VectorMultiReductionLowering::InnerParallel);
 
+/// Creates a pass that implements the SLP vectorizer.
+std::unique_ptr<Pass> createSLPVectorizerPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 7436998749791..94ccd61cb5170 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -34,4 +34,16 @@ def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::Func
   ];
 }
 
+def SLPVectorizer : Pass<"slp-vectorizer", "ModuleOp"> {
+  let summary = "SLP Vectorizer Pass";
+  let description = [{
+    This pass implements the SLP (Superword Level Parallelism) vectorizer.
+    It detects consecutive operations that can be put together into vector
+    operations. The pass works bottom-up, across basic blocks, in search of
+    scalars to combine.
+  }];
+  let constructor = "mlir::vector::createSLPVectorizerPass()";
+  let dependentDialects = ["mlir::vector::VectorDialect"];
+}
+
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8ca5cb6c6dfab..37333b739bd86 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorStep.cpp
   LowerVectorTransfer.cpp
   LowerVectorTranspose.cpp
+  SLPVectorizer.cpp
   SubsetOpInterfaceImpl.cpp
   VectorDistribute.cpp
   VectorDropLeadUnitDim.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
new file mode 100644
index 0000000000000..e9f3b12bc7461
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -0,0 +1,63 @@
+//===- SLPVectorizer.cpp - SLP Vectorizer Pass ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the SLP vectorizer pass for MLIR. The pass attempts to
+// combine similar independent operations into vector operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "slp-vectorizer"
+
+namespace mlir {
+namespace vector {
+#define GEN_PASS_DEF_SLPVECTORIZER
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+} // namespace vector
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+/// This pass implements the SLP vectorizer. It detects consecutive operations
+/// that can be put together into vector operations. The pass works bottom-up,
+/// across basic blocks, in search of scalars to combine.
+struct SLPVectorizerPass
+    : public mlir::vector::impl::SLPVectorizerBase<SLPVectorizerPass> {
+  void runOnOperation() override;
+};
+
+} // namespace
+
+void SLPVectorizerPass::runOnOperation() {
+  Operation *op = getOperation();
+  MLIRContext *context = &getContext();
+
+  // TODO: Implement SLP vectorization logic
+  // 1. Find candidate operations for vectorization
+  // 2. Build vectorization trees
+  // 3. Perform vectorization if profitable
+  // 4. Clean up scalar operations
+
+  LLVM_DEBUG(llvm::dbgs() << "Running SLP Vectorizer pass\n");
+  llvm::errs() << "Running SLP Vectorizer pass\n";
+}
+
+std::unique_ptr<Pass> mlir::vector::createSLPVectorizerPass() {
+  return std::make_unique<SLPVectorizerPass>();
+}
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
new file mode 100644
index 0000000000000..31543f3a76b2e
--- /dev/null
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s -test-slp-vectorization | FileCheck %s
+
+// CHECK-LABEL: func @basic_slp
+func.func @basic_slp(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK: vector.transfer_read
+  // CHECK: arith.addi
+  // CHECK: vector.transfer_write
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+  %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %5 = memref.load %arg1[%c1] : memref<8xi32>
+  %6 = memref.load %arg1[%c2] : memref<8xi32>
+  %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+  %8 = arith.addi %0, %4 : i32
+  %9 = arith.addi %1, %5 : i32
+  %10 = arith.addi %2, %6 : i32
+  %11 = arith.addi %3, %7 : i32
+
+  memref.store %8, %arg0[%c0] : memref<8xi32>
+  memref.store %9, %arg0[%c1] : memref<8xi32>
+  memref.store %10, %arg0[%c2] : memref<8xi32>
+  memref.store %11, %arg0[%c3] : memref<8xi32>
+
+  return
+}

>From 36bd924098956fa390972489ad3dda836da73135 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 17 May 2025 23:20:34 +0200
Subject: [PATCH 02/52] something working

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 113 +++++++++++++++++-
 mlir/test/Dialect/Vector/slp-vectorize.mlir   |   2 +-
 2 files changed, 108 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index e9f3b12bc7461..b696f36c82eee 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/Passes.h"
@@ -34,27 +35,127 @@ using namespace mlir;
 using namespace mlir::vector;
 
 namespace {
+/// A group of consecutive memory operations of the same type (load or store)
+/// that can potentially be vectorized together.
+struct MemoryOpGroup {
+  enum class Type { Load, Store };
+  Type type;
+  SmallVector<Operation *> ops;
+
+  MemoryOpGroup(Type t) : type(t) {}
+
+  bool isLoadGroup() const { return type == Type::Load; }
+  bool isStoreGroup() const { return type == Type::Store; }
+
+  size_t size() const { return ops.size(); }
+  bool empty() const { return ops.empty(); }
+};
+
 /// This pass implements the SLP vectorizer. It detects consecutive operations
 /// that can be put together into vector operations. The pass works bottom-up,
 /// across basic blocks, in search of scalars to combine.
 struct SLPVectorizerPass
     : public mlir::vector::impl::SLPVectorizerBase<SLPVectorizerPass> {
   void runOnOperation() override;
+
+private:
+  /// Collect all memory operations in the block into groups.
+  /// Each group contains either all loads or all stores, uninterrupted by
+  /// operations of the other type.
+  SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block);
 };
 
 } // namespace
 
+SmallVector<MemoryOpGroup>
+SLPVectorizerPass::collectMemoryOpGroups(Block &block) {
+  SmallVector<MemoryOpGroup> groups;
+  MemoryOpGroup *currentGroup = nullptr;
+
+  LLVM_DEBUG(llvm::dbgs() << "Scanning block for memory operations...\n");
+
+  for (Operation &op : block) {
+    LLVM_DEBUG(llvm::dbgs() << "Checking operation: " << op.getName() << "\n");
+
+    // Skip non-memory operations
+    if (!isa<memref::LoadOp, memref::StoreOp>(op)) {
+      LLVM_DEBUG(llvm::dbgs() << "  Not a memory operation\n");
+      continue;
+    }
+
+    bool isLoad = isa<memref::LoadOp>(op);
+    MemoryOpGroup::Type type =
+        isLoad ? MemoryOpGroup::Type::Load : MemoryOpGroup::Type::Store;
+
+    LLVM_DEBUG(llvm::dbgs()
+               << "  Found " << (isLoad ? "load" : "store") << " operation\n");
+
+    // Start a new group if:
+    // 1. We don't have a current group, or
+    // 2. The current operation is a different type than the current group
+    if (!currentGroup || currentGroup->type != type) {
+      LLVM_DEBUG(llvm::dbgs() << "  Starting new group\n");
+      groups.emplace_back(type);
+      currentGroup = &groups.back();
+    }
+
+    currentGroup->ops.push_back(&op);
+  }
+
+  // Remove empty groups
+  groups.erase(std::remove_if(groups.begin(), groups.end(),
+                              [](const MemoryOpGroup &g) { return g.empty(); }),
+               groups.end());
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "Found " << groups.size() << " memory operation groups:\n";
+    for (const auto &group : groups) {
+      llvm::dbgs() << "  Group type: "
+                   << (group.isLoadGroup() ? "Load" : "Store")
+                   << ", size: " << group.size() << "\n";
+    }
+  });
+
+  return groups;
+}
+
 void SLPVectorizerPass::runOnOperation() {
   Operation *op = getOperation();
   MLIRContext *context = &getContext();
 
-  // TODO: Implement SLP vectorization logic
-  // 1. Find candidate operations for vectorization
-  // 2. Build vectorization trees
-  // 3. Perform vectorization if profitable
-  // 4. Clean up scalar operations
+  LLVM_DEBUG(llvm::dbgs() << "Running SLP Vectorizer pass on operation: "
+                          << op->getName() << "\n");
+
+  // Process each function in the module
+  for (Region &region : op->getRegions()) {
+    for (Block &block : region) {
+      for (Operation &op : block) {
+        // If this is a function, process its body
+        if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
+          LLVM_DEBUG(llvm::dbgs()
+                     << "Processing function: " << funcOp.getName() << "\n");
+
+          // Process each block in the function
+          for (Block &funcBlock : funcOp.getBody()) {
+            // Collect memory operation groups
+            SmallVector<MemoryOpGroup> groups =
+                collectMemoryOpGroups(funcBlock);
+
+            LLVM_DEBUG({
+              llvm::dbgs() << "Found " << groups.size()
+                           << " memory operation groups:\n";
+              for (const auto &group : groups) {
+                llvm::dbgs() << "  Group type: "
+                             << (group.isLoadGroup() ? "Load" : "Store")
+                             << ", size: " << group.size() << "\n";
+              }
+            });
+          }
+        }
+      }
+    }
+  }
 
-  LLVM_DEBUG(llvm::dbgs() << "Running SLP Vectorizer pass\n");
   llvm::errs() << "Running SLP Vectorizer pass\n";
 }
 
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 31543f3a76b2e..a07dd05dd16aa 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-slp-vectorization | FileCheck %s
+// RUN: mlir-opt %s --slp-vectorizer | FileCheck %s
 
 // CHECK-LABEL: func @basic_slp
 func.func @basic_slp(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {

>From 5b220a7a43b2c1c32bed04e662946aecd5ee29b1 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 17 May 2025 23:29:20 +0200
Subject: [PATCH 03/52] block walk

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 64 ++++---------------
 1 file changed, 11 insertions(+), 53 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index b696f36c82eee..bec5f9d90b21b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -72,29 +72,19 @@ SLPVectorizerPass::collectMemoryOpGroups(Block &block) {
   SmallVector<MemoryOpGroup> groups;
   MemoryOpGroup *currentGroup = nullptr;
 
-  LLVM_DEBUG(llvm::dbgs() << "Scanning block for memory operations...\n");
-
   for (Operation &op : block) {
-    LLVM_DEBUG(llvm::dbgs() << "Checking operation: " << op.getName() << "\n");
-
     // Skip non-memory operations
-    if (!isa<memref::LoadOp, memref::StoreOp>(op)) {
-      LLVM_DEBUG(llvm::dbgs() << "  Not a memory operation\n");
+    if (!isa<memref::LoadOp, memref::StoreOp>(op))
       continue;
-    }
 
     bool isLoad = isa<memref::LoadOp>(op);
     MemoryOpGroup::Type type =
         isLoad ? MemoryOpGroup::Type::Load : MemoryOpGroup::Type::Store;
 
-    LLVM_DEBUG(llvm::dbgs()
-               << "  Found " << (isLoad ? "load" : "store") << " operation\n");
-
     // Start a new group if:
     // 1. We don't have a current group, or
     // 2. The current operation is a different type than the current group
     if (!currentGroup || currentGroup->type != type) {
-      LLVM_DEBUG(llvm::dbgs() << "  Starting new group\n");
       groups.emplace_back(type);
       currentGroup = &groups.back();
     }
@@ -107,15 +97,6 @@ SLPVectorizerPass::collectMemoryOpGroups(Block &block) {
                               [](const MemoryOpGroup &g) { return g.empty(); }),
                groups.end());
 
-  LLVM_DEBUG({
-    llvm::dbgs() << "Found " << groups.size() << " memory operation groups:\n";
-    for (const auto &group : groups) {
-      llvm::dbgs() << "  Group type: "
-                   << (group.isLoadGroup() ? "Load" : "Store")
-                   << ", size: " << group.size() << "\n";
-    }
-  });
-
   return groups;
 }
 
@@ -123,40 +104,17 @@ void SLPVectorizerPass::runOnOperation() {
   Operation *op = getOperation();
   MLIRContext *context = &getContext();
 
-  LLVM_DEBUG(llvm::dbgs() << "Running SLP Vectorizer pass on operation: "
-                          << op->getName() << "\n");
-
-  // Process each function in the module
-  for (Region &region : op->getRegions()) {
-    for (Block &block : region) {
-      for (Operation &op : block) {
-        // If this is a function, process its body
-        if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
-          LLVM_DEBUG(llvm::dbgs()
-                     << "Processing function: " << funcOp.getName() << "\n");
-
-          // Process each block in the function
-          for (Block &funcBlock : funcOp.getBody()) {
-            // Collect memory operation groups
-            SmallVector<MemoryOpGroup> groups =
-                collectMemoryOpGroups(funcBlock);
-
-            LLVM_DEBUG({
-              llvm::dbgs() << "Found " << groups.size()
-                           << " memory operation groups:\n";
-              for (const auto &group : groups) {
-                llvm::dbgs() << "  Group type: "
-                             << (group.isLoadGroup() ? "Load" : "Store")
-                             << ", size: " << group.size() << "\n";
-              }
-            });
-          }
-        }
-      }
-    }
-  }
+  // Walk all blocks recursively
+  op->walk([&](Block *block) {
+    LLVM_DEBUG(llvm::dbgs() << "Processing block in operation: "
+                            << block->getParentOp()->getName() << "\n");
 
-  llvm::errs() << "Running SLP Vectorizer pass\n";
+    // Collect memory operation groups
+    SmallVector<MemoryOpGroup> groups = collectMemoryOpGroups(*block);
+
+    LLVM_DEBUG(llvm::dbgs() << "Found " << groups.size()
+                            << " memory operation groups in block\n");
+  });
 }
 
 std::unique_ptr<Pass> mlir::vector::createSLPVectorizerPass() {

>From a1a52f482d4ab546e84a44d7e81aaab0941e772a Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 00:10:20 +0200
Subject: [PATCH 04/52] contiguous groups

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 106 +++++++++++++++++-
 1 file changed, 104 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index bec5f9d90b21b..f46dc71537ef3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -51,6 +51,96 @@ struct MemoryOpGroup {
   bool empty() const { return ops.empty(); }
 };
 
+// Extract contiguous groups from a MemoryOpGroup
+SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
+  SmallVector<MemoryOpGroup> result;
+  if (group.ops.empty())
+    return result;
+
+  // Keep track of which operations we've processed
+  DenseSet<Operation *> processedOps;
+
+  // Process each operation
+  for (Operation *op : group.ops) {
+    // Skip if we've already processed this operation
+    if (processedOps.contains(op))
+      continue;
+
+    // Get base and index of current operation
+    Value base;
+    int64_t index = -1;
+    if (group.isLoadGroup()) {
+      auto loadOp = cast<memref::LoadOp>(op);
+      if (auto value = getConstantIntValue(loadOp.getIndices().front())) {
+        index = *value;
+        base = loadOp.getMemRef();
+      }
+    } else {
+      auto storeOp = cast<memref::StoreOp>(op);
+      if (auto value = getConstantIntValue(storeOp.getIndices().front())) {
+        index = *value;
+        base = storeOp.getMemRef();
+      }
+    }
+    if (index == -1)
+      continue;
+
+    // Start a new group with this operation
+    result.emplace_back(group.type);
+    MemoryOpGroup &currentGroup = result.back();
+    currentGroup.ops.push_back(op);
+    processedOps.insert(op);
+
+    LLVM_DEBUG(llvm::dbgs() << "Starting new group at base " << base
+                            << " index " << index << "\n");
+
+    // Try to find operations with adjacent indices
+    bool foundMore;
+    do {
+      foundMore = false;
+      // Look for operations with index+1
+      for (Operation *otherOp : group.ops) {
+        if (processedOps.contains(otherOp))
+          continue;
+
+        Value otherBase;
+        int64_t otherIndex = -1;
+        if (group.isLoadGroup()) {
+          auto loadOp = cast<memref::LoadOp>(otherOp);
+          if (auto value = getConstantIntValue(loadOp.getIndices().front())) {
+            otherIndex = *value;
+            otherBase = loadOp.getMemRef();
+          }
+        } else {
+          auto storeOp = cast<memref::StoreOp>(otherOp);
+          if (auto value = getConstantIntValue(storeOp.getIndices().front())) {
+            otherIndex = *value;
+            otherBase = storeOp.getMemRef();
+          }
+        }
+
+        // Check if this operation has the same base and adjacent index
+        if (otherIndex != -1 && otherBase == base &&
+            otherIndex == currentGroup.ops.size()) {
+          currentGroup.ops.push_back(otherOp);
+          processedOps.insert(otherOp);
+          foundMore = true;
+          LLVM_DEBUG(llvm::dbgs()
+                     << "Added operation with index " << otherIndex << "\n");
+          break;
+        }
+      }
+    } while (foundMore);
+  }
+
+  // Remove empty groups
+  result.erase(std::remove_if(result.begin(), result.end(),
+                              [](const MemoryOpGroup &g) { return g.empty(); }),
+               result.end());
+
+  return result;
+}
+
 /// This pass implements the SLP vectorizer. It detects consecutive operations
 /// that can be put together into vector operations. The pass works bottom-up,
 /// across basic blocks, in search of scalars to combine.
@@ -112,8 +202,20 @@ void SLPVectorizerPass::runOnOperation() {
     // Collect memory operation groups
     SmallVector<MemoryOpGroup> groups = collectMemoryOpGroups(*block);
 
-    LLVM_DEBUG(llvm::dbgs() << "Found " << groups.size()
-                            << " memory operation groups in block\n");
+    // Process each group to find contiguous sequences
+    for (const auto &group : groups) {
+      SmallVector<MemoryOpGroup> contiguousGroups =
+          extractContiguousGroups(group);
+      LLVM_DEBUG({
+        llvm::dbgs() << "Found " << contiguousGroups.size()
+                     << " contiguous groups in "
+                     << (group.isLoadGroup() ? "load" : "store") << " group\n";
+        for (const auto &contigGroup : contiguousGroups) {
+          llvm::dbgs() << "  Contiguous group with " << contigGroup.size()
+                       << " operations\n";
+        }
+      });
+    }
   });
 }
 

>From 8b266979e83431e3d06bb1f665f69b7f630c2bb9 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 00:14:53 +0200
Subject: [PATCH 05/52] refac

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 55 ++++++++-----------
 1 file changed, 22 insertions(+), 33 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index f46dc71537ef3..9a0ba5264bc40 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -51,6 +51,18 @@ struct MemoryOpGroup {
   bool empty() const { return ops.empty(); }
 };
 
+// Helper function to extract base and index from a memory operation
+std::optional<std::pair<Value, int64_t>> getBaseAndIndex(Operation *op) {
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op)) {
+    if (auto value = getConstantIntValue(loadOp.getIndices().front()))
+      return std::make_pair(loadOp.getMemRef(), *value);
+  } else if (auto storeOp = dyn_cast<memref::StoreOp>(op)) {
+    if (auto value = getConstantIntValue(storeOp.getIndices().front()))
+      return std::make_pair(storeOp.getMemRef(), *value);
+  }
+  return std::nullopt;
+}
+
 // Extract contiguous groups from a MemoryOpGroup
 SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
   SmallVector<MemoryOpGroup> result;
@@ -67,24 +79,12 @@ SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
       continue;
 
     // Get base and index of current operation
-    Value base;
-    int64_t index = -1;
-    if (group.isLoadGroup()) {
-      auto loadOp = cast<memref::LoadOp>(op);
-      if (auto value = getConstantIntValue(loadOp.getIndices().front())) {
-        index = *value;
-        base = loadOp.getMemRef();
-      }
-    } else {
-      auto storeOp = cast<memref::StoreOp>(op);
-      if (auto value = getConstantIntValue(storeOp.getIndices().front())) {
-        index = *value;
-        base = storeOp.getMemRef();
-      }
-    }
-    if (index == -1)
+    auto baseAndIndex = getBaseAndIndex(op);
+    if (!baseAndIndex)
       continue;
 
+    auto [base, index] = *baseAndIndex;
+
     // Start a new group with this operation
     result.emplace_back(group.type);
     MemoryOpGroup &currentGroup = result.back();
@@ -103,25 +103,14 @@ SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
         if (processedOps.contains(otherOp))
           continue;
 
-        Value otherBase;
-        int64_t otherIndex = -1;
-        if (group.isLoadGroup()) {
-          auto loadOp = cast<memref::LoadOp>(otherOp);
-          if (auto value = getConstantIntValue(loadOp.getIndices().front())) {
-            otherIndex = *value;
-            otherBase = loadOp.getMemRef();
-          }
-        } else {
-          auto storeOp = cast<memref::StoreOp>(otherOp);
-          if (auto value = getConstantIntValue(storeOp.getIndices().front())) {
-            otherIndex = *value;
-            otherBase = storeOp.getMemRef();
-          }
-        }
+        auto otherBaseAndIndex = getBaseAndIndex(otherOp);
+        if (!otherBaseAndIndex)
+          continue;
+
+        auto [otherBase, otherIndex] = *otherBaseAndIndex;
 
         // Check if this operation has the same base and adjacent index
-        if (otherIndex != -1 && otherBase == base &&
-            otherIndex == currentGroup.ops.size()) {
+        if (otherBase == base && otherIndex == currentGroup.ops.size()) {
           currentGroup.ops.push_back(otherOp);
           processedOps.insert(otherOp);
           foundMore = true;

>From 903fee0da4587e36348f4d8359759d3de4cab652 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 00:35:42 +0200
Subject: [PATCH 06/52] SLPGraph

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 162 ++++++++++++++++++
 1 file changed, 162 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 9a0ba5264bc40..4355dc33648c0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -130,6 +130,160 @@ SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
   return result;
 }
 
+/// A node in the SLP graph representing a vectorizable operation
+struct SLPGraphNode {
+  Operation *op;
+  DenseSet<SLPGraphNode *> users;
+  DenseSet<SLPGraphNode *> operands;
+  bool isRoot = false;
+
+  SLPGraphNode(Operation *op) : op(op) {}
+};
+
+/// A graph of vectorizable operations
+class SLPGraph {
+public:
+  SLPGraph() = default;
+  ~SLPGraph() {
+    for (auto *node : nodes)
+      delete node;
+  }
+
+  /// Add a new node to the graph
+  SLPGraphNode *addNode(Operation *op) {
+    nodes.push_back(new SLPGraphNode(op));
+    return nodes.back();
+  }
+
+  /// Add a root node (memory operation)
+  SLPGraphNode *addRoot(Operation *op) {
+    auto *node = addNode(op);
+    node->isRoot = true;
+    return node;
+  }
+
+  /// Add a dependency edge between nodes
+  void addEdge(SLPGraphNode *from, SLPGraphNode *to) {
+    from->users.insert(to);
+    to->operands.insert(from);
+  }
+
+  /// Get all root nodes
+  SmallVector<SLPGraphNode *> getRoots() const {
+    SmallVector<SLPGraphNode *> roots;
+    for (auto *node : nodes)
+      if (node->isRoot)
+        roots.push_back(node);
+    return roots;
+  }
+
+  /// Print the graph structure
+  void print() const {
+    llvm::dbgs() << "SLP Graph Structure:\n";
+    llvm::dbgs() << "===================\n";
+
+    // First print all roots
+    llvm::dbgs() << "Roots:\n";
+    for (auto *node : nodes) {
+      if (!node->isRoot)
+        continue;
+      llvm::dbgs() << "  " << *node->op << "\n";
+      llvm::dbgs() << "    Users: ";
+      for (auto *user : node->users) {
+        llvm::dbgs() << "\n      " << *user->op;
+      }
+      llvm::dbgs() << "\n";
+    }
+
+    // Then print all non-root nodes
+    llvm::dbgs() << "\nNon-root nodes:\n";
+    for (auto *node : nodes) {
+      if (node->isRoot)
+        continue;
+      llvm::dbgs() << "  " << *node->op << "\n";
+      llvm::dbgs() << "    Operands: ";
+      for (auto *operand : node->operands) {
+        llvm::dbgs() << "\n      " << *operand->op;
+      }
+      llvm::dbgs() << "\n    Users: ";
+      for (auto *user : node->users) {
+        llvm::dbgs() << "\n      " << *user->op;
+      }
+      llvm::dbgs() << "\n";
+    }
+    llvm::dbgs() << "===================\n";
+  }
+
+private:
+  SmallVector<SLPGraphNode *> nodes;
+};
+
+/// Build the SLP graph starting from memory operation roots
+SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
+  SLPGraph graph;
+  DenseMap<Operation *, SLPGraphNode *> opToNode;
+
+  // First, add all memory operations as roots
+  for (const auto &group : rootGroups) {
+    for (Operation *op : group.ops) {
+      opToNode[op] = graph.addRoot(op);
+    }
+  }
+
+  // Process each root group to build the graph
+  for (const auto &group : rootGroups) {
+    for (Operation *rootOp : group.ops) {
+      // Get the value produced by this memory operation
+      Value rootValue = group.isLoadGroup()
+                            ? cast<memref::LoadOp>(rootOp).getResult()
+                            : cast<memref::StoreOp>(rootOp).getValue();
+
+      // Find all users of this value
+      for (Operation *user : rootValue.getUsers()) {
+        // Skip if we've already processed this operation
+        if (opToNode.contains(user))
+          continue;
+
+        // Check if this is a vectorizable operation
+        if (isa<arith::AddFOp, arith::AddIOp, arith::SubFOp, arith::SubIOp,
+                arith::MulFOp, arith::MulIOp>(user)) {
+          // Check if at least one other operand is already in the graph
+          bool hasGraphOperand = false;
+          for (Value operand : user->getOperands()) {
+            if (operand == rootValue)
+              continue;
+            if (auto *defOp = operand.getDefiningOp()) {
+              if (opToNode.contains(defOp)) {
+                hasGraphOperand = true;
+                break;
+              }
+            }
+          }
+
+          // Only add the operation if it has at least one other operand in the
+          // graph
+          if (hasGraphOperand) {
+            auto *node = graph.addNode(user);
+            opToNode[user] = node;
+            graph.addEdge(opToNode[rootOp], node);
+
+            // Add edges from other operands that are in the graph
+            for (Value operand : user->getOperands()) {
+              if (auto *defOp = operand.getDefiningOp()) {
+                if (opToNode.contains(defOp)) {
+                  graph.addEdge(opToNode[defOp], node);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
+  return graph;
+}
+
 /// This pass implements the SLP vectorizer. It detects consecutive operations
 /// that can be put together into vector operations. The pass works bottom-up,
 /// across basic blocks, in search of scalars to combine.
@@ -192,6 +346,7 @@ void SLPVectorizerPass::runOnOperation() {
     SmallVector<MemoryOpGroup> groups = collectMemoryOpGroups(*block);
 
     // Process each group to find contiguous sequences
+    SmallVector<MemoryOpGroup> rootGroups;
     for (const auto &group : groups) {
       SmallVector<MemoryOpGroup> contiguousGroups =
           extractContiguousGroups(group);
@@ -204,7 +359,14 @@ void SLPVectorizerPass::runOnOperation() {
                        << " operations\n";
         }
       });
+      rootGroups.append(contiguousGroups.begin(), contiguousGroups.end());
     }
+
+    // Build the SLP graph from root groups
+    SLPGraph graph = buildSLPGraph(rootGroups);
+
+    // Print the graph structure
+    LLVM_DEBUG(graph.print());
   });
 }
 

>From 4f92ea6fffa572640cc70d82f06f336b99484642 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 00:53:58 +0200
Subject: [PATCH 07/52] SLPGraph

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 129 ++++++------------
 1 file changed, 38 insertions(+), 91 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 4355dc33648c0..3c4fc3a377244 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -130,29 +130,28 @@ SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
   return result;
 }
 
-/// A node in the SLP graph representing a vectorizable operation
+/// A node in the SLP graph representing a group of vectorizable operations
 struct SLPGraphNode {
-  Operation *op;
+  SmallVector<Operation *> ops;
   DenseSet<SLPGraphNode *> users;
   DenseSet<SLPGraphNode *> operands;
   bool isRoot = false;
 
-  SLPGraphNode(Operation *op) : op(op) {}
+  SLPGraphNode() = default;
+  SLPGraphNode(Operation *op) { ops.push_back(op); }
+  void addOp(Operation *op) { ops.push_back(op); }
 };
 
 /// A graph of vectorizable operations
 class SLPGraph {
 public:
   SLPGraph() = default;
-  ~SLPGraph() {
-    for (auto *node : nodes)
-      delete node;
-  }
+  ~SLPGraph() = default;
 
   /// Add a new node to the graph
   SLPGraphNode *addNode(Operation *op) {
-    nodes.push_back(new SLPGraphNode(op));
-    return nodes.back();
+    nodes.push_back(std::make_unique<SLPGraphNode>(op));
+    return nodes.back().get();
   }
 
   /// Add a root node (memory operation)
@@ -171,9 +170,9 @@ class SLPGraph {
   /// Get all root nodes
   SmallVector<SLPGraphNode *> getRoots() const {
     SmallVector<SLPGraphNode *> roots;
-    for (auto *node : nodes)
+    for (const auto &node : nodes)
       if (node->isRoot)
-        roots.push_back(node);
+        roots.push_back(node.get());
     return roots;
   }
 
@@ -184,30 +183,50 @@ class SLPGraph {
 
     // First print all roots
     llvm::dbgs() << "Roots:\n";
-    for (auto *node : nodes) {
+    for (const auto &node : nodes) {
       if (!node->isRoot)
         continue;
-      llvm::dbgs() << "  " << *node->op << "\n";
+      llvm::dbgs() << "  "
+                   << (isa<memref::LoadOp>(node->ops[0]) ? "LOAD" : "STORE")
+                   << " group with " << node->ops.size() << " operations:\n";
+      for (auto *op : node->ops) {
+        llvm::dbgs() << "    " << *op << "\n";
+      }
       llvm::dbgs() << "    Users: ";
       for (auto *user : node->users) {
-        llvm::dbgs() << "\n      " << *user->op;
+        llvm::dbgs() << "\n      Group with " << user->ops.size()
+                     << " operations:";
+        for (auto *op : user->ops) {
+          llvm::dbgs() << "\n        " << *op;
+        }
       }
       llvm::dbgs() << "\n";
     }
 
     // Then print all non-root nodes
     llvm::dbgs() << "\nNon-root nodes:\n";
-    for (auto *node : nodes) {
+    for (const auto &node : nodes) {
       if (node->isRoot)
         continue;
-      llvm::dbgs() << "  " << *node->op << "\n";
+      llvm::dbgs() << "  Group with " << node->ops.size() << " operations:\n";
+      for (auto *op : node->ops) {
+        llvm::dbgs() << "    " << *op << "\n";
+      }
       llvm::dbgs() << "    Operands: ";
       for (auto *operand : node->operands) {
-        llvm::dbgs() << "\n      " << *operand->op;
+        llvm::dbgs() << "\n      Group with " << operand->ops.size()
+                     << " operations:";
+        for (auto *op : operand->ops) {
+          llvm::dbgs() << "\n        " << *op;
+        }
       }
       llvm::dbgs() << "\n    Users: ";
       for (auto *user : node->users) {
-        llvm::dbgs() << "\n      " << *user->op;
+        llvm::dbgs() << "\n      Group with " << user->ops.size()
+                     << " operations:";
+        for (auto *op : user->ops) {
+          llvm::dbgs() << "\n        " << *op;
+        }
       }
       llvm::dbgs() << "\n";
     }
@@ -215,75 +234,9 @@ class SLPGraph {
   }
 
 private:
-  SmallVector<SLPGraphNode *> nodes;
+  SmallVector<std::unique_ptr<SLPGraphNode>> nodes;
 };
 
-/// Build the SLP graph starting from memory operation roots
-SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
-  SLPGraph graph;
-  DenseMap<Operation *, SLPGraphNode *> opToNode;
-
-  // First, add all memory operations as roots
-  for (const auto &group : rootGroups) {
-    for (Operation *op : group.ops) {
-      opToNode[op] = graph.addRoot(op);
-    }
-  }
-
-  // Process each root group to build the graph
-  for (const auto &group : rootGroups) {
-    for (Operation *rootOp : group.ops) {
-      // Get the value produced by this memory operation
-      Value rootValue = group.isLoadGroup()
-                            ? cast<memref::LoadOp>(rootOp).getResult()
-                            : cast<memref::StoreOp>(rootOp).getValue();
-
-      // Find all users of this value
-      for (Operation *user : rootValue.getUsers()) {
-        // Skip if we've already processed this operation
-        if (opToNode.contains(user))
-          continue;
-
-        // Check if this is a vectorizable operation
-        if (isa<arith::AddFOp, arith::AddIOp, arith::SubFOp, arith::SubIOp,
-                arith::MulFOp, arith::MulIOp>(user)) {
-          // Check if at least one other operand is already in the graph
-          bool hasGraphOperand = false;
-          for (Value operand : user->getOperands()) {
-            if (operand == rootValue)
-              continue;
-            if (auto *defOp = operand.getDefiningOp()) {
-              if (opToNode.contains(defOp)) {
-                hasGraphOperand = true;
-                break;
-              }
-            }
-          }
-
-          // Only add the operation if it has at least one other operand in the
-          // graph
-          if (hasGraphOperand) {
-            auto *node = graph.addNode(user);
-            opToNode[user] = node;
-            graph.addEdge(opToNode[rootOp], node);
-
-            // Add edges from other operands that are in the graph
-            for (Value operand : user->getOperands()) {
-              if (auto *defOp = operand.getDefiningOp()) {
-                if (opToNode.contains(defOp)) {
-                  graph.addEdge(opToNode[defOp], node);
-                }
-              }
-            }
-          }
-        }
-      }
-    }
-  }
-
-  return graph;
-}
-
 /// This pass implements the SLP vectorizer. It detects consecutive operations
 /// that can be put together into vector operations. The pass works bottom-up,
 /// across basic blocks, in search of scalars to combine.
@@ -361,12 +314,6 @@ void SLPVectorizerPass::runOnOperation() {
       });
       rootGroups.append(contiguousGroups.begin(), contiguousGroups.end());
     }
-
-    // Build the SLP graph from root groups
-    SLPGraph graph = buildSLPGraph(rootGroups);
-
-    // Print the graph structure
-    LLVM_DEBUG(graph.print());
   });
 }
 

>From 8dcc9bc4cf77beba2f1142cad92c011e52ae0271 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 01:04:27 +0200
Subject: [PATCH 08/52] work

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 48 ++++++++++++++++---
 1 file changed, 41 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 3c4fc3a377244..8e49b622ac39b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -138,8 +138,8 @@ struct SLPGraphNode {
   bool isRoot = false;
 
   SLPGraphNode() = default;
-  SLPGraphNode(Operation *op) { ops.push_back(op); }
-  void addOp(Operation *op) { ops.push_back(op); }
+  SLPGraphNode(ArrayRef<Operation *> operations)
+      : ops(operations.begin(), operations.end()) {}
 };
 
 /// A graph of vectorizable operations
@@ -148,15 +148,23 @@ class SLPGraph {
   SLPGraph() = default;
   ~SLPGraph() = default;
 
+  // Delete copy constructor and assignment operator
+  SLPGraph(const SLPGraph &) = delete;
+  SLPGraph &operator=(const SLPGraph &) = delete;
+
+  // Allow move operations
+  SLPGraph(SLPGraph &&) = default;
+  SLPGraph &operator=(SLPGraph &&) = default;
+
   /// Add a new node to the graph
-  SLPGraphNode *addNode(Operation *op) {
-    nodes.push_back(std::make_unique<SLPGraphNode>(op));
+  SLPGraphNode *addNode(ArrayRef<Operation *> operations) {
+    nodes.push_back(std::make_unique<SLPGraphNode>(operations));
     return nodes.back().get();
   }
 
   /// Add a root node (memory operation)
-  SLPGraphNode *addRoot(Operation *op) {
-    auto *node = addNode(op);
+  SLPGraphNode *addRoot(ArrayRef<Operation *> operations) {
+    auto *node = addNode(operations);
     node->isRoot = true;
     return node;
   }
@@ -251,7 +259,25 @@ struct SLPVectorizerPass
   SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block);
 };
 
-} // namespace
+/// Build the SLP graph starting from memory operation groups
+SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
+  SLPGraph graph;
+
+  // First, create nodes for each contiguous memory operation group
+  for (const auto &group : rootGroups) {
+    // Create a new node for this group
+    auto *node = graph.addRoot(group.ops);
+    node->isRoot = true;
+
+    LLVM_DEBUG({
+      llvm::dbgs() << "Created " << (group.isLoadGroup() ? "LOAD" : "STORE")
+                   << " group node with " << node->ops.size()
+                   << " operations\n";
+    });
+  }
+
+  return graph;
+}
 
 SmallVector<MemoryOpGroup>
 SLPVectorizerPass::collectMemoryOpGroups(Block &block) {
@@ -314,9 +340,17 @@ void SLPVectorizerPass::runOnOperation() {
       });
       rootGroups.append(contiguousGroups.begin(), contiguousGroups.end());
     }
+
+    // Build the SLP graph from root groups
+    SLPGraph graph = buildSLPGraph(rootGroups);
+
+    // Print the graph structure
+    LLVM_DEBUG(graph.print());
   });
 }
 
+} // namespace
+
 std::unique_ptr<Pass> mlir::vector::createSLPVectorizerPass() {
   return std::make_unique<SLPVectorizerPass>();
 }

>From cee30a40bab9d2d474970880cc6b8c859a23ba2d Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 12:55:28 +0200
Subject: [PATCH 09/52] fingerprinting

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 170 ++++++++++++++++--
 1 file changed, 158 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 8e49b622ac39b..3e6a4ca05f87d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/SHA1.h"
 
 #define DEBUG_TYPE "slp-vectorizer"
 
@@ -64,7 +65,8 @@ std::optional<std::pair<Value, int64_t>> getBaseAndIndex(Operation *op) {
 }
 
 // Extract contiguous groups from a MemoryOpGroup
-SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
+static SmallVector<MemoryOpGroup>
+extractContiguousGroups(const MemoryOpGroup &group) {
   SmallVector<MemoryOpGroup> result;
   if (group.ops.empty())
     return result;
@@ -133,8 +135,8 @@ SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
 /// A node in the SLP graph representing a group of vectorizable operations
 struct SLPGraphNode {
   SmallVector<Operation *> ops;
-  DenseSet<SLPGraphNode *> users;
-  DenseSet<SLPGraphNode *> operands;
+  llvm::SmallDenseSet<SLPGraphNode *> users;
+  llvm::SmallDenseSet<SLPGraphNode *> operands;
   bool isRoot = false;
 
   SLPGraphNode() = default;
@@ -148,11 +150,9 @@ class SLPGraph {
   SLPGraph() = default;
   ~SLPGraph() = default;
 
-  // Delete copy constructor and assignment operator
   SLPGraph(const SLPGraph &) = delete;
   SLPGraph &operator=(const SLPGraph &) = delete;
 
-  // Allow move operations
   SLPGraph(SLPGraph &&) = default;
   SLPGraph &operator=(SLPGraph &&) = default;
 
@@ -259,21 +259,168 @@ struct SLPVectorizerPass
   SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block);
 };
 
+static bool isVectorizable(Operation *op) {
+  return OpTrait::hasElementwiseMappableTraits(op);
+}
+
+using Fingerprint = std::array<uint8_t, 20>;
+
+template <typename T>
+static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
+  hasher.update(
+      ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
+}
+
+struct OperationsFingerprint {
+  OperationsFingerprint(
+      const llvm::SmallDenseMap<Operation *, SLPGraphNode *> &opToNode)
+      : opToNode(opToNode) {}
+
+  Fingerprint getFingerprint(Operation *op) {
+    auto it = fingerprints.find(op);
+    if (it != fingerprints.end())
+      return it->second;
+
+    SmallVector<Operation *> worklist;
+    SmallVector<Operation *> toposortedOps;
+    worklist.emplace_back(op);
+    while (!worklist.empty()) {
+      Operation *op = worklist.pop_back_val();
+      toposortedOps.emplace_back(op);
+      if (opToNode.contains(op))
+        continue;
+
+      for (Value operand : op->getOperands()) {
+        auto *defOp = operand.getDefiningOp();
+        if (!defOp || !isVectorizable(defOp))
+          continue;
+
+        toposortedOps.emplace_back(defOp);
+        worklist.emplace_back(defOp);
+      }
+    }
+
+    for (Operation *op : llvm::reverse(toposortedOps)) {
+      llvm::SHA1 hasher;
+      addDataToHash(hasher, op->getName().getTypeID());
+      addDataToHash(hasher, op->getRawDictionaryAttrs());
+      addDataToHash(hasher, op->hashProperties());
+      for (Value operand : op->getOperands()) {
+        auto *defOp = operand.getDefiningOp();
+        if (!defOp)
+          continue;
+
+        auto it1 = opToNode.find(defOp);
+        if (it1 != opToNode.end()) {
+          addDataToHash(hasher, it1->second);
+          continue;
+        }
+
+        auto it2 = fingerprints.find(defOp);
+        if (it2 != fingerprints.end()) {
+          addDataToHash(hasher, it2->second);
+          continue;
+        }
+      }
+      fingerprints[op] = hasher.result();
+    }
+
+    return fingerprints[op];
+  }
+
+  void invalidate(Operation *op) {
+    if (fingerprints.contains(op))
+      fingerprints.clear();
+  }
+
+  const llvm::SmallDenseMap<Operation *, SLPGraphNode *> &opToNode;
+  DenseMap<Operation *, Fingerprint> fingerprints;
+};
+
+static bool isEquivalent(Operation *op1, Operation *op2) {
+  if (op1->getName() != op2->getName())
+    return false;
+
+  if (op1->getRawDictionaryAttrs() != op2->getRawDictionaryAttrs())
+    return false;
+
+  return true;
+}
+
 /// Build the SLP graph starting from memory operation groups
-SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
+static SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
   SLPGraph graph;
+  llvm::SmallDenseMap<Operation *, SLPGraphNode *> opToNode;
+
+  SmallVector<SLPGraphNode *> worklist;
 
   // First, create nodes for each contiguous memory operation group
   for (const auto &group : rootGroups) {
-    // Create a new node for this group
     auto *node = graph.addRoot(group.ops);
-    node->isRoot = true;
+    for (Operation *op : group.ops)
+      opToNode[op] = node;
+
+    worklist.push_back(node);
 
     LLVM_DEBUG({
-      llvm::dbgs() << "Created " << (group.isLoadGroup() ? "LOAD" : "STORE")
-                   << " group node with " << node->ops.size()
-                   << " operations\n";
+      llvm::dbgs() << "Created root group node with " << node->ops.size()
+                   << " operations of type "
+                   << (group.type == MemoryOpGroup::Type::Load ? "Load"
+                                                               : "Store")
+                   << "\n";
     });
+
+    OperationsFingerprint fingerprints(opToNode);
+
+    auto processUse = [&](SLPGraphNode *node, OpOperand &use) {
+      Operation *user = use.getOwner();
+      if (opToNode.contains(user) || !isVectorizable(user))
+        return;
+
+      Fingerprint expectedFingerprint = fingerprints.getFingerprint(user);
+
+      SmallVector<Operation *> currentOps;
+      currentOps.emplace_back(user);
+      for (Operation *op : ArrayRef(node->ops).drop_front()) {
+        Operation *found = nullptr;
+        for (OpOperand &opUse : op->getUses()) {
+          if (opUse.getOperandNumber() != use.getOperandNumber())
+            continue;
+
+          Operation *useOwner = opUse.getOwner();
+          if (!isEquivalent(useOwner, user) ||
+              fingerprints.getFingerprint(useOwner) != expectedFingerprint)
+            continue;
+
+          found = useOwner;
+          break;
+        }
+        if (!found)
+          break;
+
+        currentOps.push_back(found);
+      }
+
+      if (currentOps.size() == 1)
+        return;
+
+      auto *newNode = graph.addNode(currentOps);
+      graph.addEdge(node, newNode);
+      for (Operation *op : currentOps) {
+        opToNode[op] = newNode;
+        fingerprints.invalidate(op);
+      }
+
+      worklist.push_back(newNode);
+    };
+
+    while (!worklist.empty()) {
+      SLPGraphNode *node = worklist.pop_back_val();
+
+      Operation *op = node->ops.front();
+      for (OpOperand &use : op->getUses())
+        processUse(node, use);
+    }
   }
 
   return graph;
@@ -314,7 +461,6 @@ SLPVectorizerPass::collectMemoryOpGroups(Block &block) {
 
 void SLPVectorizerPass::runOnOperation() {
   Operation *op = getOperation();
-  MLIRContext *context = &getContext();
 
   // Walk all blocks recursively
   op->walk([&](Block *block) {

>From 0045a6f5abea75220af7f4777f658a2c83ae0415 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 13:58:50 +0200
Subject: [PATCH 10/52] graph

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 106 ++++++++++--------
 1 file changed, 62 insertions(+), 44 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 3e6a4ca05f87d..28c53efea7512 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -185,7 +185,7 @@ class SLPGraph {
   }
 
   /// Print the graph structure
-  void print() const {
+  [[maybe_unused]] void print() const {
     llvm::dbgs() << "SLP Graph Structure:\n";
     llvm::dbgs() << "===================\n";
 
@@ -348,7 +348,12 @@ static bool isEquivalent(Operation *op1, Operation *op2) {
 }
 
 /// Build the SLP graph starting from memory operation groups
-static SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
+static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
+  if (rootGroups.empty())
+    return SLPGraph();
+
+  LLVM_DEBUG(llvm::dbgs() << "=== Building SLP graph from " << rootGroups.size()
+                          << " root groups ===\n");
   SLPGraph graph;
   llvm::SmallDenseMap<Operation *, SLPGraphNode *> opToNode;
 
@@ -365,61 +370,74 @@ static SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
     LLVM_DEBUG({
       llvm::dbgs() << "Created root group node with " << node->ops.size()
                    << " operations of type "
-                   << (group.type == MemoryOpGroup::Type::Load ? "Load"
-                                                               : "Store")
-                   << "\n";
+                   << (group.isLoadGroup() ? "Load" : "Store") << "\n";
     });
+  }
 
-    OperationsFingerprint fingerprints(opToNode);
-
-    auto processUse = [&](SLPGraphNode *node, OpOperand &use) {
-      Operation *user = use.getOwner();
-      if (opToNode.contains(user) || !isVectorizable(user))
-        return;
+  OperationsFingerprint fingerprints(opToNode);
+
+  auto processUse = [&](SLPGraphNode *node, OpOperand &use) {
+    Operation *user = use.getOwner();
+    auto it = opToNode.find(user);
+    if (it != opToNode.end()) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "  Adding edge from " << node->ops.front()->getName()
+                 << " to " << it->first->getName() << "\n");
+      graph.addEdge(node, it->second);
+      return;
+    }
 
-      Fingerprint expectedFingerprint = fingerprints.getFingerprint(user);
+    if (!isVectorizable(user))
+      return;
 
-      SmallVector<Operation *> currentOps;
-      currentOps.emplace_back(user);
-      for (Operation *op : ArrayRef(node->ops).drop_front()) {
-        Operation *found = nullptr;
-        for (OpOperand &opUse : op->getUses()) {
-          if (opUse.getOperandNumber() != use.getOperandNumber())
-            continue;
+    Fingerprint expectedFingerprint = fingerprints.getFingerprint(user);
 
-          Operation *useOwner = opUse.getOwner();
-          if (!isEquivalent(useOwner, user) ||
-              fingerprints.getFingerprint(useOwner) != expectedFingerprint)
-            continue;
+    SmallVector<Operation *> currentOps;
+    currentOps.emplace_back(user);
+    for (Operation *op : ArrayRef(node->ops).drop_front()) {
+      Operation *found = nullptr;
+      for (OpOperand &opUse : op->getUses()) {
+        if (opUse.getOperandNumber() != use.getOperandNumber())
+          continue;
 
-          found = useOwner;
-          break;
-        }
-        if (!found)
-          break;
+        Operation *useOwner = opUse.getOwner();
+        if (!isEquivalent(useOwner, user) ||
+            fingerprints.getFingerprint(useOwner) != expectedFingerprint)
+          continue;
 
-        currentOps.push_back(found);
+        found = useOwner;
+        break;
       }
+      if (!found)
+        break;
 
-      if (currentOps.size() == 1)
-        return;
+      currentOps.push_back(found);
+    }
 
-      auto *newNode = graph.addNode(currentOps);
-      graph.addEdge(node, newNode);
-      for (Operation *op : currentOps) {
-        opToNode[op] = newNode;
-        fingerprints.invalidate(op);
-      }
+    if (currentOps.size() == 1)
+      return;
 
-      worklist.push_back(newNode);
-    };
+    auto *newNode = graph.addNode(currentOps);
+    graph.addEdge(node, newNode);
+    for (Operation *op : currentOps) {
+      opToNode[op] = newNode;
+      fingerprints.invalidate(op);
+    }
 
-    while (!worklist.empty()) {
-      SLPGraphNode *node = worklist.pop_back_val();
+    worklist.push_back(newNode);
+  };
+
+  while (!worklist.empty()) {
+    SLPGraphNode *node = worklist.pop_back_val();
+    LLVM_DEBUG(llvm::dbgs() << "Processing node with " << node->ops.size()
+                            << " operations, first op: "
+                            << node->ops.front()->getName() << "\n");
 
-      Operation *op = node->ops.front();
-      for (OpOperand &use : op->getUses())
-        processUse(node, use);
+    Operation *op = node->ops.front();
+    for (OpOperand &use : op->getUses()) {
+      processUse(node, use);
+      LLVM_DEBUG(llvm::dbgs() << "  Processing use in operation: "
+                              << use.getOwner()->getName() << "\n");
     }
   }
 

>From 8c29c3a291506d81565f605df314cf30b82f214b Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 14:05:07 +0200
Subject: [PATCH 11/52] refac

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 41 ++++++++++---------
 1 file changed, 22 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 28c53efea7512..e3b39ba10373c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -159,7 +159,10 @@ class SLPGraph {
   /// Add a new node to the graph
   SLPGraphNode *addNode(ArrayRef<Operation *> operations) {
     nodes.push_back(std::make_unique<SLPGraphNode>(operations));
-    return nodes.back().get();
+    auto *node = nodes.back().get();
+    for (Operation *op : operations)
+      opToNode[op] = node;
+    return node;
   }
 
   /// Add a root node (memory operation)
@@ -184,6 +187,12 @@ class SLPGraph {
     return roots;
   }
 
+  /// Get the node associated with an operation
+  SLPGraphNode *getNodeForOp(Operation *op) const {
+    auto it = opToNode.find(op);
+    return it != opToNode.end() ? it->second : nullptr;
+  }
+
   /// Print the graph structure
   [[maybe_unused]] void print() const {
     llvm::dbgs() << "SLP Graph Structure:\n";
@@ -243,6 +252,7 @@ class SLPGraph {
 
 private:
   SmallVector<std::unique_ptr<SLPGraphNode>> nodes;
+  llvm::SmallDenseMap<Operation *, SLPGraphNode *> opToNode;
 };
 
 /// This pass implements the SLP vectorizer. It detects consecutive operations
@@ -272,9 +282,7 @@ static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
 }
 
 struct OperationsFingerprint {
-  OperationsFingerprint(
-      const llvm::SmallDenseMap<Operation *, SLPGraphNode *> &opToNode)
-      : opToNode(opToNode) {}
+  OperationsFingerprint(const SLPGraph &graph) : graph(graph) {}
 
   Fingerprint getFingerprint(Operation *op) {
     auto it = fingerprints.find(op);
@@ -287,7 +295,7 @@ struct OperationsFingerprint {
     while (!worklist.empty()) {
       Operation *op = worklist.pop_back_val();
       toposortedOps.emplace_back(op);
-      if (opToNode.contains(op))
+      if (graph.getNodeForOp(op))
         continue;
 
       for (Value operand : op->getOperands()) {
@@ -310,9 +318,9 @@ struct OperationsFingerprint {
         if (!defOp)
           continue;
 
-        auto it1 = opToNode.find(defOp);
-        if (it1 != opToNode.end()) {
-          addDataToHash(hasher, it1->second);
+        auto *node = graph.getNodeForOp(defOp);
+        if (node) {
+          addDataToHash(hasher, node);
           continue;
         }
 
@@ -333,7 +341,7 @@ struct OperationsFingerprint {
       fingerprints.clear();
   }
 
-  const llvm::SmallDenseMap<Operation *, SLPGraphNode *> &opToNode;
+  const SLPGraph &graph;
   DenseMap<Operation *, Fingerprint> fingerprints;
 };
 
@@ -355,16 +363,12 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
   LLVM_DEBUG(llvm::dbgs() << "=== Building SLP graph from " << rootGroups.size()
                           << " root groups ===\n");
   SLPGraph graph;
-  llvm::SmallDenseMap<Operation *, SLPGraphNode *> opToNode;
 
   SmallVector<SLPGraphNode *> worklist;
 
   // First, create nodes for each contiguous memory operation group
   for (const auto &group : rootGroups) {
     auto *node = graph.addRoot(group.ops);
-    for (Operation *op : group.ops)
-      opToNode[op] = node;
-
     worklist.push_back(node);
 
     LLVM_DEBUG({
@@ -374,16 +378,16 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
     });
   }
 
-  OperationsFingerprint fingerprints(opToNode);
+  OperationsFingerprint fingerprints(graph);
 
   auto processUse = [&](SLPGraphNode *node, OpOperand &use) {
     Operation *user = use.getOwner();
-    auto it = opToNode.find(user);
-    if (it != opToNode.end()) {
+    auto *existingNode = graph.getNodeForOp(user);
+    if (existingNode) {
       LLVM_DEBUG(llvm::dbgs()
                  << "  Adding edge from " << node->ops.front()->getName()
-                 << " to " << it->first->getName() << "\n");
-      graph.addEdge(node, it->second);
+                 << " to " << user->getName() << "\n");
+      graph.addEdge(node, existingNode);
       return;
     }
 
@@ -420,7 +424,6 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
     auto *newNode = graph.addNode(currentOps);
     graph.addEdge(node, newNode);
     for (Operation *op : currentOps) {
-      opToNode[op] = newNode;
       fingerprints.invalidate(op);
     }
 

>From bda45c9202c3b068fadc0026976bfdd3fce79daf Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 14:59:20 +0200
Subject: [PATCH 12/52] toposort

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 89 ++++++++++++++++++-
 1 file changed, 85 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index e3b39ba10373c..8f0137a12d07b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -135,8 +135,8 @@ extractContiguousGroups(const MemoryOpGroup &group) {
 /// A node in the SLP graph representing a group of vectorizable operations
 struct SLPGraphNode {
   SmallVector<Operation *> ops;
-  llvm::SmallDenseSet<SLPGraphNode *> users;
-  llvm::SmallDenseSet<SLPGraphNode *> operands;
+  SmallVector<SLPGraphNode *> users;
+  SmallVector<SLPGraphNode *> operands;
   bool isRoot = false;
 
   SLPGraphNode() = default;
@@ -174,8 +174,8 @@ class SLPGraph {
 
   /// Add a dependency edge between nodes
   void addEdge(SLPGraphNode *from, SLPGraphNode *to) {
-    from->users.insert(to);
-    to->operands.insert(from);
+    from->users.push_back(to);
+    to->operands.push_back(from);
   }
 
   /// Get all root nodes
@@ -193,6 +193,80 @@ class SLPGraph {
     return it != opToNode.end() ? it->second : nullptr;
   }
 
+  /// Topologically sort the nodes in the graph
+  SmallVector<SLPGraphNode *> topologicalSort() const {
+    SmallVector<SLPGraphNode *> result;
+    llvm::SmallDenseSet<SLPGraphNode *> visited;
+
+    SmallVector<SLPGraphNode *> stack;
+
+    // Process each node
+    for (const auto &node : nodes) {
+      if (visited.contains(node.get()))
+        continue;
+
+      stack.emplace_back(node.get());
+      while (!stack.empty()) {
+        SLPGraphNode *node = stack.pop_back_val();
+        if (visited.contains(node))
+          continue;
+
+        stack.push_back(node);
+
+        bool pushed = false;
+        for (SLPGraphNode *operand : node->operands) {
+          if (visited.contains(operand))
+            continue;
+
+          stack.push_back(operand);
+          pushed = true;
+        }
+
+        if (!pushed) {
+          visited.insert(node);
+          result.push_back(node);
+        }
+      }
+    }
+
+    return result;
+  }
+
+  /// Vectorize the operations in the graph
+  LogicalResult vectorize(IRRewriter &rewriter) {
+    if (nodes.empty())
+      return success();
+
+    LLVM_DEBUG(llvm::dbgs()
+               << "Vectorizing SLP graph with " << nodes.size() << " nodes\n");
+
+    // Get topologically sorted nodes
+    SmallVector<SLPGraphNode *> sortedNodes = topologicalSort();
+    if (sortedNodes.empty()) {
+      LLVM_DEBUG(llvm::dbgs() << "Failed to topologically sort nodes\n");
+      return failure();
+    }
+
+    LLVM_DEBUG({
+      llvm::dbgs() << "Topologically sorted nodes:\n";
+      for (auto *node : sortedNodes) {
+        llvm::dbgs() << "  Node with " << node->ops.size()
+                     << " operations: " << node->ops.front()->getName() << "\n";
+      }
+    });
+
+    // TODO: Implement vectorization logic:
+    // 1. Process nodes in topological order
+    // 2. For each node:
+    //    a. Check if all operands are vectorized
+    //    b. Create vector operation
+    //    c. Replace scalar operations with vector operation
+    // 3. Handle memory operations (loads/stores) specially
+    // 4. Update use-def chains
+
+    return success();
+  }
+
   /// Print the graph structure
   [[maybe_unused]] void print() const {
     llvm::dbgs() << "SLP Graph Structure:\n";
@@ -513,6 +587,13 @@ void SLPVectorizerPass::runOnOperation() {
 
     // Print the graph structure
     LLVM_DEBUG(graph.print());
+
+    // Vectorize the graph
+    IRRewriter rewriter(&getContext());
+    if (failed(graph.vectorize(rewriter))) {
+      LLVM_DEBUG(llvm::dbgs() << "Failed to vectorize graph\n");
+      return signalPassFailure();
+    }
   });
 }
 

>From 18666e164ac6b5a512748d01739074ff38ee178c Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 15:30:01 +0200
Subject: [PATCH 13/52] codegen

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 68 +++++++++++++++----
 1 file changed, 56 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 8f0137a12d07b..095ad4f11a91a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -132,6 +132,10 @@ extractContiguousGroups(const MemoryOpGroup &group) {
   return result;
 }
 
+static bool isVectorizable(Operation *op) {
+  return OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1;
+}
+
 /// A node in the SLP graph representing a group of vectorizable operations
 struct SLPGraphNode {
   SmallVector<Operation *> ops;
@@ -255,14 +259,58 @@ class SLPGraph {
       }
     });
 
-    // TODO: Implement vectorization logic:
-    // 1. Process nodes in topological order
-    // 2. For each node:
-    //    a. Check if all operands are vectorized
-    //    b. Create vector operation
-    //    c. Replace scalar operations with vector operation
-    // 3. Handle memory operations (loads/stores) specially
-    // 4. Update use-def chains
+    IRMapping mapping;
+    for (auto *node : sortedNodes) {
+      if (node->users.empty() && node->operands.empty())
+        continue;
+
+      Operation *op = node->ops.front();
+      rewriter.setInsertionPoint(op);
+      Location loc = op->getLoc();
+      int64_t numElements = node->ops.size();
+      if (auto load = dyn_cast<memref::LoadOp>(op)) {
+        auto vecType =
+            VectorType::get(numElements, load.getMemRefType().getElementType());
+        Value result = rewriter.create<vector::LoadOp>(
+            loc, vecType, load.getMemRef(), load.getIndices());
+        mapping.map(load.getResult(), result);
+      } else if (auto store = dyn_cast<memref::StoreOp>(op)) {
+        Value val = mapping.lookupOrDefault(store.getValueToStore());
+        rewriter.create<vector::StoreOp>(loc, val, store.getMemRef(),
+                                         store.getIndices());
+      } else if (isVectorizable(op)) {
+        auto vecType =
+            VectorType::get(numElements, op->getResultTypes().front());
+        for (auto [i, operand] : llvm::enumerate(op->getOperands())) {
+          if (getNodeForOp(operand.getDefiningOp()))
+            continue;
+
+          SmallVector<Value> args;
+          for (Operation *defOp : node->ops)
+            args.push_back(defOp->getOperand(i));
+
+          Value vector =
+              rewriter.create<vector::FromElementsOp>(loc, vecType, args);
+          mapping.map(operand, vector);
+        }
+
+        Operation *newOp = rewriter.clone(*op, mapping);
+        auto resVectorType =
+            VectorType::get(numElements, op->getResultTypes().front());
+        newOp->getResult(0).setType(resVectorType);
+
+        mapping.map(op->getResults(), newOp->getResults());
+      } else {
+        op->emitError("unsupported operation");
+        return failure();
+      }
+    }
+
+    for (auto *node : llvm::reverse(sortedNodes)) {
+      for (Operation *op : node->ops) {
+        rewriter.eraseOp(op);
+      }
+    }
 
     return success();
   }
@@ -343,10 +391,6 @@ struct SLPVectorizerPass
   SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block);
 };
 
-static bool isVectorizable(Operation *op) {
-  return OpTrait::hasElementwiseMappableTraits(op);
-}
-
 using Fingerprint = std::array<uint8_t, 20>;
 
 template <typename T>

>From 7d2d82583bc95142b4dcc9970369fad70ec667ed Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 15:53:41 +0200
Subject: [PATCH 14/52] test

---
 mlir/test/Dialect/Vector/slp-vectorize.mlir | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index a07dd05dd16aa..266008e53ea43 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -1,10 +1,13 @@
 // RUN: mlir-opt %s --slp-vectorizer | FileCheck %s
 
-// CHECK-LABEL: func @basic_slp
-func.func @basic_slp(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
-  // CHECK: vector.transfer_read
-  // CHECK: arith.addi
-  // CHECK: vector.transfer_write
+// CHECK-LABEL: func @read_read_add_write
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+  // CHECK:     vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index

>From 5274bd9fb3ac9872f8f88e7f9c7368bb4c13acba Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 15:57:06 +0200
Subject: [PATCH 15/52] test

---
 mlir/test/Dialect/Vector/slp-vectorize.mlir | 25 +++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 266008e53ea43..28a255f90a869 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -1,5 +1,30 @@
 // RUN: mlir-opt %s --slp-vectorizer | FileCheck %s
 
+// CHECK-LABEL: func @read_write
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[RES:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+  memref.store %0, %arg0[%c0] : memref<8xi32>
+  memref.store %1, %arg0[%c1] : memref<8xi32>
+  memref.store %2, %arg0[%c2] : memref<8xi32>
+  memref.store %3, %arg0[%c3] : memref<8xi32>
+
+  return
+}
+
+
 // CHECK-LABEL: func @read_read_add_write
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
 func.func @read_read_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {

>From e0f0c295a39895b707f9ca9cf66c02945baeed1d Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 16:25:45 +0200
Subject: [PATCH 16/52] fixes

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 49 +++++++++++++------
 mlir/test/Dialect/Vector/slp-vectorize.mlir   | 36 ++++++++++++++
 2 files changed, 70 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 095ad4f11a91a..a40131a1b10ff 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -264,24 +264,13 @@ class SLPGraph {
       if (node->users.empty() && node->operands.empty())
         continue;
 
+      int64_t numElements = node->ops.size();
       Operation *op = node->ops.front();
       rewriter.setInsertionPoint(op);
       Location loc = op->getLoc();
-      int64_t numElements = node->ops.size();
-      if (auto load = dyn_cast<memref::LoadOp>(op)) {
-        auto vecType =
-            VectorType::get(numElements, load.getMemRefType().getElementType());
-        Value result = rewriter.create<vector::LoadOp>(
-            loc, vecType, load.getMemRef(), load.getIndices());
-        mapping.map(load.getResult(), result);
-      } else if (auto store = dyn_cast<memref::StoreOp>(op)) {
-        Value val = mapping.lookupOrDefault(store.getValueToStore());
-        rewriter.create<vector::StoreOp>(loc, val, store.getMemRef(),
-                                         store.getIndices());
-      } else if (isVectorizable(op)) {
-        auto vecType =
-            VectorType::get(numElements, op->getResultTypes().front());
-        for (auto [i, operand] : llvm::enumerate(op->getOperands())) {
+
+      auto handleNonVectorInputs = [&](ValueRange operands) {
+        for (auto [i, operand] : llvm::enumerate(operands)) {
           if (getNodeForOp(operand.getDefiningOp()))
             continue;
 
@@ -289,17 +278,47 @@ class SLPGraph {
           for (Operation *defOp : node->ops)
             args.push_back(defOp->getOperand(i));
 
+          auto vecType = VectorType::get(numElements, operand.getType());
           Value vector =
               rewriter.create<vector::FromElementsOp>(loc, vecType, args);
           mapping.map(operand, vector);
         }
+      };
+
+      auto handleNonVectorOutputs = [&](Value newResult) {
+        for (auto [i, result] : llvm::enumerate(node->ops)) {
+          // use.set() moves `use` onto `elem`'s use-list: iterate early-inc.
+          for (OpOperand &use : llvm::make_early_inc_range(result->getUses())) {
+            if (getNodeForOp(use.getOwner()))
+              continue;
+            // Out-of-graph user: rewire it to lane `i` of the vector result.
+            Value elem = rewriter.create<vector::ExtractOp>(loc, newResult, i);
+            use.set(elem);
+          }
+        }
+      };
 
+      if (auto load = dyn_cast<memref::LoadOp>(op)) {
+        auto vecType =
+            VectorType::get(numElements, load.getMemRefType().getElementType());
+        Value result = rewriter.create<vector::LoadOp>(
+            loc, vecType, load.getMemRef(), load.getIndices());
+        mapping.map(load.getResult(), result);
+        handleNonVectorOutputs(result);
+      } else if (auto store = dyn_cast<memref::StoreOp>(op)) {
+        handleNonVectorInputs(store.getValueToStore());
+        Value val = mapping.lookupOrDefault(store.getValueToStore());
+        rewriter.create<vector::StoreOp>(loc, val, store.getMemRef(),
+                                         store.getIndices());
+      } else if (isVectorizable(op)) {
+        handleNonVectorInputs(op->getOperands());
         Operation *newOp = rewriter.clone(*op, mapping);
         auto resVectorType =
             VectorType::get(numElements, op->getResultTypes().front());
         newOp->getResult(0).setType(resVectorType);
 
         mapping.map(op->getResults(), newOp->getResults());
+        handleNonVectorOutputs(newOp->getResult(0));
       } else {
         op->emitError("unsupported operation");
         return failure();
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 28a255f90a869..2b2b91d667e00 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s --slp-vectorizer | FileCheck %s
 
+
 // CHECK-LABEL: func @read_write
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
 func.func @read_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
@@ -24,6 +25,41 @@ func.func @read_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
   return
 }
 
+// CHECK-LABEL: func @read_read_add
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i32, i32, i32){
+  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+  // CHECK:     %[[R0:.*]] = vector.extract %[[RES]][0] : i32 from vector<4xi32>
+  // CHECK:     %[[R1:.*]] = vector.extract %[[RES]][1] : i32 from vector<4xi32>
+  // CHECK:     %[[R2:.*]] = vector.extract %[[RES]][2] : i32 from vector<4xi32>
+  // CHECK:     %[[R3:.*]] = vector.extract %[[RES]][3] : i32 from vector<4xi32>
+  // CHECK:     return %[[R0]], %[[R1]], %[[R2]], %[[R3]] : i32, i32, i32, i32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+  %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %5 = memref.load %arg1[%c1] : memref<8xi32>
+  %6 = memref.load %arg1[%c2] : memref<8xi32>
+  %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+  %8 = arith.addi %0, %4 : i32
+  %9 = arith.addi %1, %5 : i32
+  %10 = arith.addi %2, %6 : i32
+  %11 = arith.addi %3, %7 : i32
+
+  return %8, %9, %10, %11 : i32, i32, i32, i32
+}
+
 
 // CHECK-LABEL: func @read_read_add_write
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)

>From 659825b990e77797fcd989c3f9681048c557e949 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 17:41:39 +0200
Subject: [PATCH 17/52] fixes

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 56 +++++++++++++++++--
 mlir/test/Dialect/Vector/slp-vectorize.mlir   | 29 ++++++++++
 2 files changed, 79 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index a40131a1b10ff..ab0b3f549192f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -259,9 +259,13 @@ class SLPGraph {
       }
     });
 
+    auto isGoodNode = [&](SLPGraphNode *node) {
+      return node->users.empty() && node->operands.empty();
+    };
+
     IRMapping mapping;
     for (auto *node : sortedNodes) {
-      if (node->users.empty() && node->operands.empty())
+      if (isGoodNode(node))
         continue;
 
       int64_t numElements = node->ops.size();
@@ -326,6 +330,9 @@ class SLPGraph {
     }
 
     for (auto *node : llvm::reverse(sortedNodes)) {
+      if (isGoodNode(node))
+        continue;
+
       for (Operation *op : node->ops) {
         rewriter.eraseOp(op);
       }
@@ -560,10 +567,47 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
 
     auto *newNode = graph.addNode(currentOps);
     graph.addEdge(node, newNode);
-    for (Operation *op : currentOps) {
+    for (Operation *op : currentOps)
       fingerprints.invalidate(op);
+
+    worklist.push_back(newNode);
+  };
+
+  auto processOperands = [&](SLPGraphNode *node, Value operand, int64_t index) {
+    Operation *srcOp = operand.getDefiningOp();
+    if (!srcOp)
+      return;
+
+    auto *existingNode = graph.getNodeForOp(srcOp);
+    if (existingNode) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "  Adding edge from " << srcOp->getName() << " to "
+                 << node->ops.front()->getName() << "\n");
+      graph.addEdge(existingNode, node);
+      return;
+    }
+
+    if (!isVectorizable(srcOp))
+      return;
+
+    SmallVector<Operation *> currentOps;
+    currentOps.emplace_back(srcOp);
+    for (Operation *op : ArrayRef(node->ops).drop_front()) {
+      Operation *otherOp = op->getOperand(index).getDefiningOp();
+      if (!otherOp || !isEquivalent(otherOp, srcOp))
+        break;
+
+      currentOps.push_back(otherOp);
     }
 
+    if (currentOps.size() == 1)
+      return;
+
+    auto *newNode = graph.addNode(currentOps);
+    graph.addEdge(newNode, node);
+    for (Operation *op : currentOps)
+      fingerprints.invalidate(op);
+
     worklist.push_back(newNode);
   };
 
@@ -574,11 +618,11 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
                             << node->ops.front()->getName() << "\n");
 
     Operation *op = node->ops.front();
-    for (OpOperand &use : op->getUses()) {
+    for (OpOperand &use : op->getUses())
       processUse(node, use);
-      LLVM_DEBUG(llvm::dbgs() << "  Processing use in operation: "
-                              << use.getOwner()->getName() << "\n");
-    }
+
+    for (auto [i, operand] : llvm::enumerate(op->getOperands()))
+      processOperands(node, operand, i);
   }
 
   return graph;
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 2b2b91d667e00..036e1fcbed5d5 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -60,6 +60,35 @@ func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i3
   return %8, %9, %10, %11 : i32, i32, i32, i32
 }
 
+// CHECK-LABEL: func @add_write
+//  CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: i32, %[[ARG7:.*]]: i32, %[[ARG8:.*]]: memref<8xi32>)
+func.func @add_write(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32,
+                     %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32,
+                     %arg8: memref<8xi32>) {
+  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[A:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]] : vector<4xi32>
+  // CHECK:     %[[B:.*]] = vector.from_elements %[[ARG4]], %[[ARG5]], %[[ARG6]], %[[ARG7]] : vector<4xi32>
+  // CHECK:     %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+  // CHECK:     vector.store %[[RES]], %[[ARG8]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %8 = arith.addi %arg0, %arg4 : i32
+  %9 = arith.addi %arg1, %arg5 : i32
+  %10 = arith.addi %arg2, %arg6 : i32
+  %11 = arith.addi %arg3, %arg7 : i32
+
+  memref.store %8, %arg8[%c0] : memref<8xi32>
+  memref.store %9, %arg8[%c1] : memref<8xi32>
+  memref.store %10, %arg8[%c2] : memref<8xi32>
+  memref.store %11, %arg8[%c3] : memref<8xi32>
+
+  return
+}
+
+
 
 // CHECK-LABEL: func @read_read_add_write
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)

>From edd6cb83352f9716d630ff9be7fe614c4e693605 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 18:07:37 +0200
Subject: [PATCH 18/52] handle size mismatch

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 21 +++++++
 mlir/test/Dialect/Vector/slp-vectorize.mlir   | 62 ++++++++++++++++++-
 2 files changed, 82 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index ab0b3f549192f..f54a9aba0e6c0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -302,6 +302,16 @@ class SLPGraph {
         }
       };
 
+      auto handleVecSizeMismatch = [&](Value arg) -> Value {
+        auto srcType = cast<VectorType>(arg.getType());
+        assert(srcType.getRank() == 1);
+        if (srcType.getDimSize(0) == numElements)
+          return arg;
+
+        return rewriter.create<vector::ExtractStridedSliceOp>(loc, arg, 0,
+                                                              numElements, 1);
+      };
+
       if (auto load = dyn_cast<memref::LoadOp>(op)) {
         auto vecType =
             VectorType::get(numElements, load.getMemRefType().getElementType());
@@ -312,6 +322,7 @@ class SLPGraph {
       } else if (auto store = dyn_cast<memref::StoreOp>(op)) {
         handleNonVectorInputs(store.getValueToStore());
         Value val = mapping.lookupOrDefault(store.getValueToStore());
+        val = handleVecSizeMismatch(val);
         rewriter.create<vector::StoreOp>(loc, val, store.getMemRef(),
                                          store.getIndices());
       } else if (isVectorizable(op)) {
@@ -319,6 +330,15 @@ class SLPGraph {
         Operation *newOp = rewriter.clone(*op, mapping);
         auto resVectorType =
             VectorType::get(numElements, op->getResultTypes().front());
+
+        {
+          OpBuilder::InsertionGuard guard(rewriter);
+          rewriter.setInsertionPoint(newOp);
+          for (OpOperand &operand : newOp->getOpOperands()) {
+            Value newOperand = handleVecSizeMismatch(operand.get());
+            operand.set(newOperand);
+          }
+        }
         newOp->getResult(0).setType(resVectorType);
 
         mapping.map(op->getResults(), newOp->getResults());
@@ -701,6 +721,7 @@ void SLPVectorizerPass::runOnOperation() {
       LLVM_DEBUG(llvm::dbgs() << "Failed to vectorize graph\n");
       return signalPassFailure();
     }
+    op->dump();
   });
 }
 
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 036e1fcbed5d5..76592833a78b4 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -25,6 +25,31 @@ func.func @read_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
   return
 }
 
+
+// CHECK-LABEL: func @read_write_size_mismatch
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_write_size_mismatch(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[RES:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[RES1:.*]] = vector.extract_strided_slice %[[RES]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+  // CHECK:     vector.store %[[RES1]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+  memref.store %0, %arg0[%c0] : memref<8xi32>
+  memref.store %1, %arg0[%c1] : memref<8xi32>
+
+  return
+}
+
+
 // CHECK-LABEL: func @read_read_add
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
 func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i32, i32, i32){
@@ -60,6 +85,7 @@ func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i3
   return %8, %9, %10, %11 : i32, i32, i32, i32
 }
 
+
 // CHECK-LABEL: func @add_write
 //  CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: i32, %[[ARG7:.*]]: i32, %[[ARG8:.*]]: memref<8xi32>)
 func.func @add_write(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32,
@@ -89,7 +115,6 @@ func.func @add_write(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32,
 }
 
 
-
 // CHECK-LABEL: func @read_read_add_write
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
 func.func @read_read_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
@@ -125,3 +150,38 @@ func.func @read_read_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
 
   return
 }
+
+
+// CHECK-LABEL: func @read_read_add_write_size_mismatch
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_size_mismatch(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[A1:.*]] = vector.extract_strided_slice %[[A]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+  // CHECK:     %[[B1:.*]] = vector.extract_strided_slice %[[B]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+  // CHECK:     %[[RES:.*]] = arith.addi %[[A1]], %[[B1]] : vector<2xi32>
+  // CHECK:     vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+  %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %5 = memref.load %arg1[%c1] : memref<8xi32>
+  %6 = memref.load %arg1[%c2] : memref<8xi32>
+  %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+  %8 = arith.addi %0, %4 : i32
+  %9 = arith.addi %1, %5 : i32
+
+  memref.store %8, %arg0[%c0] : memref<8xi32>
+  memref.store %9, %arg0[%c1] : memref<8xi32>
+
+  return
+}

>From a0d251b3124357f716dea667d95ee37b83c30d0e Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 18:53:54 +0200
Subject: [PATCH 19/52] adjacent indices

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 98 ++++++++++---------
 mlir/test/Dialect/Vector/slp-vectorize.mlir   | 25 +++++
 2 files changed, 79 insertions(+), 44 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index f54a9aba0e6c0..cc252a0e32c06 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -52,16 +52,65 @@ struct MemoryOpGroup {
   bool empty() const { return ops.empty(); }
 };
 
-// Helper function to extract base and index from a memory operation
-std::optional<std::pair<Value, int64_t>> getBaseAndIndex(Operation *op) {
-  if (auto loadOp = dyn_cast<memref::LoadOp>(op)) {
-    if (auto value = getConstantIntValue(loadOp.getIndices().front()))
-      return std::make_pair(loadOp.getMemRef(), *value);
-  } else if (auto storeOp = dyn_cast<memref::StoreOp>(op)) {
-    if (auto value = getConstantIntValue(storeOp.getIndices().front()))
-      return std::make_pair(storeOp.getMemRef(), *value);
+/// Returns the memref accessed by a memref.load/memref.store, or a null
+/// Value for any other operation.
+static Value getBase(Operation *op) {
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return loadOp.getMemRef();
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return storeOp.getMemRef();
+  return {};
+}
+
+/// Returns the index operands of a memref.load/memref.store, or an empty
+/// range for any other operation.
+static ValueRange getIndices(Operation *op) {
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return loadOp.getIndices();
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return storeOp.getIndices();
+  return {};
+}
+
+/// Returns the scalar type read by a memref.load or written by a
+/// memref.store, or a null Type for any other operation.
+static Type getElementType(Operation *op) {
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return loadOp.getResult().getType();
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return storeOp.getValueToStore().getType();
+  return {};
+}
+
+/// Returns true if both indices are constants and `idx2 == idx1 + 1`.
+static bool isAdjacentIndices(Value idx1, Value idx2) {
+  if (auto c1 = getConstantIntValue(idx1)) {
+    if (auto c2 = getConstantIntValue(idx2))
+      return *c1 + 1 == *c2;
   }
-  return std::nullopt;
+  return false;
+}
+
+/// Returns true if the index lists have the same non-zero length, identical
+/// leading indices, and adjacent innermost indices.
+static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2) {
+  if (idx1.empty() || idx1.size() != idx2.size())
+    return false;
+
+  if (idx1.drop_back() != idx2.drop_back())
+    return false;
+
+  return isAdjacentIndices(idx1.back(), idx2.back());
+}
+
+/// Returns true if `op1` and `op2` access the same memref with the same
+/// element type and `op2` touches the element immediately after `op1`.
+// The memref check matters: nothing here guarantees a group only touches one
+// memref, and equal constant indices into different memrefs are unrelated.
+static bool isAdjacentIndices(Operation *op1, Operation *op2) {
+  return getBase(op1) == getBase(op2) &&
+         getElementType(op1) == getElementType(op2) &&
+         isAdjacentIndices(getIndices(op1), getIndices(op2));
 }
 
 // Extract contiguous groups from a MemoryOpGroup
@@ -71,64 +98,48 @@ extractContiguousGroups(const MemoryOpGroup &group) {
   if (group.ops.empty())
     return result;
 
-  // Keep track of which operations we've processed
-  DenseSet<Operation *> processedOps;
+  llvm::SmallDenseSet<Operation *> processedOps;
 
-  // Process each operation
   for (Operation *op : group.ops) {
-    // Skip if we've already processed this operation
     if (processedOps.contains(op))
       continue;
 
-    // Get base and index of current operation
-    auto baseAndIndex = getBaseAndIndex(op);
-    if (!baseAndIndex)
-      continue;
-
-    auto [base, index] = *baseAndIndex;
-
     // Start a new group with this operation
     result.emplace_back(group.type);
     MemoryOpGroup &currentGroup = result.back();
-    currentGroup.ops.push_back(op);
+    auto &currentOps = currentGroup.ops;
+    currentOps.push_back(op);
     processedOps.insert(op);
 
-    LLVM_DEBUG(llvm::dbgs() << "Starting new group at base " << base
-                            << " index " << index << "\n");
-
-    // Try to find operations with adjacent indices
     bool foundMore;
     do {
       foundMore = false;
-      // Look for operations with index+1
       for (Operation *otherOp : group.ops) {
         if (processedOps.contains(otherOp))
           continue;
 
-        auto otherBaseAndIndex = getBaseAndIndex(otherOp);
-        if (!otherBaseAndIndex)
-          continue;
-
-        auto [otherBase, otherIndex] = *otherBaseAndIndex;
-
-        // Check if this operation has the same base and adjacent index
-        if (otherBase == base && otherIndex == currentGroup.ops.size()) {
-          currentGroup.ops.push_back(otherOp);
+        Operation *firstOp = currentOps.front();
+        Operation *lastOp = currentOps.back();
+        if (isAdjacentIndices(otherOp, firstOp)) {
+          currentOps.insert(currentOps.begin(), otherOp);
+          processedOps.insert(otherOp);
+          foundMore = true;
+        } else if (isAdjacentIndices(lastOp, otherOp)) {
+          currentOps.push_back(otherOp);
           processedOps.insert(otherOp);
           foundMore = true;
-          LLVM_DEBUG(llvm::dbgs()
-                     << "Added operation with index " << otherIndex << "\n");
-          break;
         }
       }
     } while (foundMore);
-  }
 
-  // Remove empty groups
-  result.erase(std::remove_if(result.begin(), result.end(),
-                              [](const MemoryOpGroup &g) { return g.empty(); }),
-               result.end());
+    if (currentOps.size() <= 1) {
+      result.pop_back();
+      continue;
+    }
 
+    LLVM_DEBUG(llvm::dbgs() << "Extracted contiguous group with "
+                            << currentGroup.ops.size() << " operations\n");
+  }
   return result;
 }
 
@@ -721,7 +732,6 @@ void SLPVectorizerPass::runOnOperation() {
       LLVM_DEBUG(llvm::dbgs() << "Failed to vectorize graph\n");
       return signalPassFailure();
     }
-    op->dump();
   });
 }
 
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 76592833a78b4..6be405ad078b9 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -50,6 +50,31 @@ func.func @read_write_size_mistamtch(%arg0: memref<8xi32>, %arg1: memref<8xi32>)
 }
 
 
+// CHECK-LABEL: func @read_write_interleaved
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_write_interleaved(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[RES:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+
+  memref.store %1, %arg0[%c1] : memref<8xi32>
+  memref.store %0, %arg0[%c0] : memref<8xi32>
+  memref.store %3, %arg0[%c3] : memref<8xi32>
+  memref.store %2, %arg0[%c2] : memref<8xi32>
+
+  return
+}
+
+
 // CHECK-LABEL: func @read_read_add
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
 func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i32, i32, i32){

>From 785f7568349377d9bfbefda73c5fc92c9122184d Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 18:57:52 +0200
Subject: [PATCH 20/52] test

---
 mlir/test/Dialect/Vector/slp-vectorize.mlir | 38 +++++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 6be405ad078b9..9c5005f807c71 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -210,3 +210,41 @@ func.func @read_read_add_write_size_mismatch(%arg0: memref<8xi32>, %arg1: memref
 
   return
 }
+
+
+// CHECK-LABEL: func @read_read_add_write_interleaved
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_interleaved(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+  // CHECK:     vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+  %7 = memref.load %arg1[%c3] : memref<8xi32>
+  %11 = arith.addi %3, %7 : i32
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %8 = arith.addi %0, %4 : i32
+
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %6 = memref.load %arg1[%c2] : memref<8xi32>
+  %10 = arith.addi %2, %6 : i32
+
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %5 = memref.load %arg1[%c1] : memref<8xi32>
+  %9 = arith.addi %1, %5 : i32
+
+  memref.store %8, %arg0[%c0] : memref<8xi32>
+  memref.store %11, %arg0[%c3] : memref<8xi32>
+  memref.store %10, %arg0[%c2] : memref<8xi32>
+  memref.store %9, %arg0[%c1] : memref<8xi32>
+
+  return
+}

>From 6c18d433545d8854688caa140afe7fc18c961add Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 19:07:26 +0200
Subject: [PATCH 21/52] fixes and test

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 11 +++-
 mlir/test/Dialect/Vector/slp-vectorize.mlir   | 54 +++++++++++++++++++
 2 files changed, 64 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index cc252a0e32c06..3ff46093d9fbe 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -52,6 +52,14 @@ struct MemoryOpGroup {
   bool empty() const { return ops.empty(); }
 };
 
+static Value getBase(Operation *op) {
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return loadOp.getMemRef();
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return storeOp.getMemRef();
+  return {};
+}
+
 static ValueRange getIndices(Operation *op) {
   if (auto loadOp = dyn_cast<memref::LoadOp>(op))
     return loadOp.getIndices();
@@ -87,7 +95,8 @@ static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2) {
 }
 
 static bool isAdjacentIndices(Operation *op1, Operation *op2) {
-  return getElementType(op1) == getElementType(op2) &&
+  return getBase(op1) == getBase(op2) &&
+         getElementType(op1) == getElementType(op2) &&
          isAdjacentIndices(getIndices(op1), getIndices(op2));
 }
 
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 9c5005f807c71..820fbf2d260cd 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -248,3 +248,57 @@ func.func @read_read_add_write_interleaved(%arg0: memref<8xi32>, %arg1: memref<8
 
   return
 }
+
+
+// CHECK-LABEL: func @read_read_add_add_write
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>
+//  CHECK-SAME: , %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
+func.func @read_read_add_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>,
+                                   %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) {
+  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[ADD1:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+  // CHECK:     %[[C:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]] : vector<4xi32>
+  // CHECK:     %[[ADD2:.*]] = arith.addi %[[A]], %[[C]] : vector<4xi32>
+  // CHECK:     vector.store %[[ADD1]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     vector.store %[[ADD2]], %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+  %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %5 = memref.load %arg1[%c1] : memref<8xi32>
+  %6 = memref.load %arg1[%c2] : memref<8xi32>
+  %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+  %8 = arith.addi %0, %4 : i32
+  %12 = arith.addi %0, %arg2 : i32
+
+  %13 = arith.addi %1, %arg3 : i32
+  %9 = arith.addi %1, %5 : i32
+
+  %10 = arith.addi %2, %6 : i32
+  %14 = arith.addi %2, %arg4 : i32
+
+  %15 = arith.addi %3, %arg5 : i32
+  %11 = arith.addi %3, %7 : i32
+
+  memref.store %8, %arg0[%c0] : memref<8xi32>
+  memref.store %9, %arg0[%c1] : memref<8xi32>
+  memref.store %10, %arg0[%c2] : memref<8xi32>
+  memref.store %11, %arg0[%c3] : memref<8xi32>
+
+  memref.store %12, %arg1[%c0] : memref<8xi32>
+  memref.store %13, %arg1[%c1] : memref<8xi32>
+  memref.store %14, %arg1[%c2] : memref<8xi32>
+  memref.store %15, %arg1[%c3] : memref<8xi32>
+
+  return
+}

>From 3e912a3d5ab5003819eafec5d17a957ad3bba9f5 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 19:22:03 +0200
Subject: [PATCH 22/52] better side effects handling

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 94 +++++++++++--------
 1 file changed, 55 insertions(+), 39 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 3ff46093d9fbe..6cb6faa486702 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -52,6 +52,61 @@ struct MemoryOpGroup {
   bool empty() const { return ops.empty(); }
 };
 
+static bool isReadOp(Operation *op) {
+  auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
+  if (!effectInterface)
+    return true;
+
+  return effectInterface.hasEffect<MemoryEffects::Read>();
+}
+
+static bool isWriteOp(Operation *op) {
+  auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
+  if (!effectInterface)
+    return true;
+
+  return effectInterface.hasEffect<MemoryEffects::Write>();
+}
+
+/// Collect all memory operations in the block into groups.
+/// Each group contains either all loads or all stores; a load group is split
+/// by any op that may write memory, a store group by any op that may read it.
+static SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block) {
+  SmallVector<MemoryOpGroup> groups;
+  MemoryOpGroup *currentGroup = nullptr;
+
+  for (Operation &op : block) {
+    if (currentGroup) {
+      if (currentGroup->isLoadGroup() && isWriteOp(&op)) {
+        currentGroup = nullptr;
+      } else if (currentGroup->isStoreGroup() && isReadOp(&op)) {
+        currentGroup = nullptr;
+      }
+    }
+
+    if (!isa<memref::LoadOp, memref::StoreOp>(op))
+      continue;
+
+    bool isLoad = isReadOp(&op);
+    MemoryOpGroup::Type type =
+        isLoad ? MemoryOpGroup::Type::Load : MemoryOpGroup::Type::Store;
+
+    if (!currentGroup) {
+      groups.emplace_back(type);
+      currentGroup = &groups.back();
+    }
+
+    currentGroup->ops.push_back(&op);
+  }
+
+  // Remove empty groups
+  groups.erase(std::remove_if(groups.begin(), groups.end(),
+                              [](const MemoryOpGroup &g) { return g.empty(); }),
+               groups.end());
+
+  return groups;
+}
+
 static Value getBase(Operation *op) {
   if (auto loadOp = dyn_cast<memref::LoadOp>(op))
     return loadOp.getMemRef();
@@ -449,12 +504,6 @@ class SLPGraph {
 struct SLPVectorizerPass
     : public mlir::vector::impl::SLPVectorizerBase<SLPVectorizerPass> {
   void runOnOperation() override;
-
-private:
-  /// Collect all memory operations in the block into groups.
-  /// Each group contains either all loads or all stores, uninterrupted by
-  /// operations of the other type.
-  SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block);
 };
 
 using Fingerprint = std::array<uint8_t, 20>;
@@ -668,39 +717,6 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
   return graph;
 }
 
-SmallVector<MemoryOpGroup>
-SLPVectorizerPass::collectMemoryOpGroups(Block &block) {
-  SmallVector<MemoryOpGroup> groups;
-  MemoryOpGroup *currentGroup = nullptr;
-
-  for (Operation &op : block) {
-    // Skip non-memory operations
-    if (!isa<memref::LoadOp, memref::StoreOp>(op))
-      continue;
-
-    bool isLoad = isa<memref::LoadOp>(op);
-    MemoryOpGroup::Type type =
-        isLoad ? MemoryOpGroup::Type::Load : MemoryOpGroup::Type::Store;
-
-    // Start a new group if:
-    // 1. We don't have a current group, or
-    // 2. The current operation is a different type than the current group
-    if (!currentGroup || currentGroup->type != type) {
-      groups.emplace_back(type);
-      currentGroup = &groups.back();
-    }
-
-    currentGroup->ops.push_back(&op);
-  }
-
-  // Remove empty groups
-  groups.erase(std::remove_if(groups.begin(), groups.end(),
-                              [](const MemoryOpGroup &g) { return g.empty(); }),
-               groups.end());
-
-  return groups;
-}
-
 void SLPVectorizerPass::runOnOperation() {
   Operation *op = getOperation();
 

>From c851b5da448d4e25071495d773899205d37d2614 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 19:44:47 +0200
Subject: [PATCH 23/52] cleanup

---
 .../mlir/Dialect/Vector/Transforms/Passes.h   |  3 --
 .../mlir/Dialect/Vector/Transforms/Passes.td  | 14 ++++--
 .../Vector/Transforms/SLPVectorizer.cpp       | 49 ++++++++++++-------
 mlir/test/Dialect/Vector/slp-vectorize.mlir   |  2 +-
 4 files changed, 43 insertions(+), 25 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index 43112f084dc60..5667f4fa95ace 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -25,9 +25,6 @@ std::unique_ptr<Pass> createLowerVectorMultiReductionPass(
     VectorMultiReductionLowering option =
         VectorMultiReductionLowering::InnerParallel);
 
-/// Creates a pass that implements the SLP vectorizer.
-std::unique_ptr<Pass> createSLPVectorizerPass();
-
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 94ccd61cb5170..d5c31c9f78409 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -34,15 +34,21 @@ def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::Func
   ];
 }
 
-def SLPVectorizer : Pass<"slp-vectorizer", "ModuleOp"> {
+def GreedySLPVectorizer : Pass<"greedy-slp-vectorizer"> {
   let summary = "SLP Vectorizer Pass";
   let description = [{
     This pass implements the SLP (Superword Level Parallelism) vectorizer.
     It detects consecutive operations that can be put together into vector
-    operations. The pass works bottom-up, across basic blocks, in search of
-    scalars to combine.
+    operations. The pass works bi-directionally, starting from loads or stores,
+    in search of scalars to combine.
+
+    This is a greedy vectorizer: it doesn't have any cost model (yet) and it
+    tries to create vector ops whenever there are at least 2 candidate ops.
+
+    It doesn't check whether the target actually supports the resulting
+    vectors either; the user will need a follow-up pass that splits large
+    and/or unaligned vectors into sizes actually supported by the target.
   }];
-  let constructor = "mlir::vector::createSLPVectorizerPass()";
   let dependentDialects = ["mlir::vector::VectorDialect"];
 }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 6cb6faa486702..d7c2dc3845cac 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -27,7 +27,7 @@
 
 namespace mlir {
 namespace vector {
-#define GEN_PASS_DEF_SLPVECTORIZER
+#define GEN_PASS_DEF_GREEDYSLPVECTORIZER
 #include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
 } // namespace vector
 } // namespace mlir
@@ -115,6 +115,19 @@ static Value getBase(Operation *op) {
   return {};
 }
 
+static bool isContiguousLastDim(Value val) {
+  auto memrefType = dyn_cast<MemRefType>(val.getType());
+  if (!memrefType)
+    return false;
+
+  int64_t offset;
+  SmallVector<int64_t> strides;
+  if (failed(memrefType.getStridesAndOffset(strides, offset)))
+    return false;
+
+  return !strides.empty() && strides.back() == 1;
+}
+
 static ValueRange getIndices(Operation *op) {
   if (auto loadOp = dyn_cast<memref::LoadOp>(op))
     return loadOp.getIndices();
@@ -150,8 +163,15 @@ static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2) {
 }
 
 static bool isAdjacentIndices(Operation *op1, Operation *op2) {
-  return getBase(op1) == getBase(op2) &&
-         getElementType(op1) == getElementType(op2) &&
+  Value base1 = getBase(op1);
+  Value base2 = getBase(op2);
+  if (base1 != base2)
+    return false;
+
+  if (!isContiguousLastDim(base1))
+    return false;
+
+  return getElementType(op1) == getElementType(op2) &&
          isAdjacentIndices(getIndices(op1), getIndices(op2));
 }
 
@@ -498,11 +518,9 @@ class SLPGraph {
   llvm::SmallDenseMap<Operation *, SLPGraphNode *> opToNode;
 };
 
-/// This pass implements the SLP vectorizer. It detects consecutive operations
-/// that can be put together into vector operations. The pass works bottom-up,
-/// across basic blocks, in search of scalars to combine.
-struct SLPVectorizerPass
-    : public mlir::vector::impl::SLPVectorizerBase<SLPVectorizerPass> {
+struct GreedySLPVectorizerPass
+    : public mlir::vector::impl::GreedySLPVectorizerBase<
+          GreedySLPVectorizerPass> {
   void runOnOperation() override;
 };
 
@@ -717,11 +735,11 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
   return graph;
 }
 
-void SLPVectorizerPass::runOnOperation() {
+void GreedySLPVectorizerPass::runOnOperation() {
   Operation *op = getOperation();
 
   // Walk all blocks recursively
-  op->walk([&](Block *block) {
+  op->walk([&](Block *block) -> WalkResult {
     LLVM_DEBUG(llvm::dbgs() << "Processing block in operation: "
                             << block->getParentOp()->getName() << "\n");
 
@@ -747,21 +765,18 @@ void SLPVectorizerPass::runOnOperation() {
 
     // Build the SLP graph from root groups
     SLPGraph graph = buildSLPGraph(rootGroups);
-
-    // Print the graph structure
     LLVM_DEBUG(graph.print());
 
     // Vectorize the graph
     IRRewriter rewriter(&getContext());
     if (failed(graph.vectorize(rewriter))) {
       LLVM_DEBUG(llvm::dbgs() << "Failed to vectorize graph\n");
-      return signalPassFailure();
+      signalPassFailure();
+      return WalkResult::interrupt();
     }
+
+    return WalkResult::advance();
   });
 }
 
 } // namespace
-
-std::unique_ptr<Pass> mlir::vector::createSLPVectorizerPass() {
-  return std::make_unique<SLPVectorizerPass>();
-}
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 820fbf2d260cd..2e9298d11ed05 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --slp-vectorizer | FileCheck %s
+// RUN: mlir-opt %s --greedy-slp-vectorizer | FileCheck %s
 
 
 // CHECK-LABEL: func @read_write

>From bcabdf75314c15271652e9d34245e06a199dca5c Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 20:10:02 +0200
Subject: [PATCH 24/52] cleanup

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 80 +++++++++++++------
 1 file changed, 57 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index d7c2dc3845cac..24059ec355b30 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -49,7 +49,6 @@ struct MemoryOpGroup {
   bool isStoreGroup() const { return type == Type::Store; }
 
   size_t size() const { return ops.size(); }
-  bool empty() const { return ops.empty(); }
 };
 
 static bool isReadOp(Operation *op) {
@@ -99,11 +98,6 @@ static SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block) {
     currentGroup->ops.push_back(&op);
   }
 
-  // Remove empty groups
-  groups.erase(std::remove_if(groups.begin(), groups.end(),
-                              [](const MemoryOpGroup &g) { return g.empty(); }),
-               groups.end());
-
   return groups;
 }
 
@@ -144,14 +138,19 @@ static Type getElementType(Operation *op) {
   return {};
 }
 
+/// Check if two indices are consecutive, i.e. `idx2 == idx1 + 1`.
 static bool isAdjacentIndices(Value idx1, Value idx2) {
   if (auto c1 = getConstantIntValue(idx1)) {
     if (auto c2 = getConstantIntValue(idx2))
       return *c1 + 1 == *c2;
   }
+
+  // TODO: Check arith.add, affine.apply, etc
   return false;
 }
 
+/// Check if two ranges of indices are consecutive, i.e. the last (fastest)
+/// index differs by 1 and all other indices are the same.
 static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2) {
   if (idx1.empty() || idx1.size() != idx2.size())
     return false;
@@ -162,7 +161,10 @@ static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2) {
   return isAdjacentIndices(idx1.back(), idx2.back());
 }
 
-static bool isAdjacentIndices(Operation *op1, Operation *op2) {
+/// Check if two operations are adjacent and can be combined into a vector op.
+/// This is done by checking if the base memrefs are the same, the last
+/// dimension is contiguous, element types match, and indices are adjacent.
+static bool isAdjacentOps(Operation *op1, Operation *op2) {
   Value base1 = getBase(op1);
   Value base2 = getBase(op2);
   if (base1 != base2)
@@ -195,6 +197,8 @@ extractContiguousGroups(const MemoryOpGroup &group) {
     currentOps.push_back(op);
     processedOps.insert(op);
 
+    // Keep adding ops to the beginning or end of the current group until no
+    // more ops can be added.
     bool foundMore;
     do {
       foundMore = false;
@@ -204,11 +208,11 @@ extractContiguousGroups(const MemoryOpGroup &group) {
 
         Operation *firstOp = currentOps.front();
         Operation *lastOp = currentOps.back();
-        if (isAdjacentIndices(otherOp, firstOp)) {
+        if (isAdjacentOps(otherOp, firstOp)) {
           currentOps.insert(currentOps.begin(), otherOp);
           processedOps.insert(otherOp);
           foundMore = true;
-        } else if (isAdjacentIndices(lastOp, otherOp)) {
+        } else if (isAdjacentOps(lastOp, otherOp)) {
           currentOps.push_back(otherOp);
           processedOps.insert(otherOp);
           foundMore = true;
@@ -222,7 +226,7 @@ extractContiguousGroups(const MemoryOpGroup &group) {
     }
 
     LLVM_DEBUG(llvm::dbgs() << "Extracted contiguous group with "
-                            << currentGroup.ops.size() << " operations\n");
+                            << currentGroup.size() << " operations\n");
   }
   return result;
 }
@@ -241,6 +245,8 @@ struct SLPGraphNode {
   SLPGraphNode() = default;
   SLPGraphNode(ArrayRef<Operation *> operations)
       : ops(operations.begin(), operations.end()) {}
+
+  size_t size() const { return ops.size(); }
 };
 
 /// A graph of vectorizable operations
@@ -349,7 +355,7 @@ class SLPGraph {
     LLVM_DEBUG({
       llvm::dbgs() << "Topologically sorted nodes:\n";
       for (auto *node : sortedNodes) {
-        llvm::dbgs() << "  Node with " << node->ops.size()
+        llvm::dbgs() << "  Node with " << node->size()
                      << " operations: " << node->ops.front()->getName() << "\n";
       }
     });
@@ -363,7 +369,7 @@ class SLPGraph {
       if (isGoodNode(node))
         continue;
 
-      int64_t numElements = node->ops.size();
+      int64_t numElements = node->size();
       Operation *op = node->ops.front();
       rewriter.setInsertionPoint(op);
       Location loc = op->getLoc();
@@ -467,15 +473,15 @@ class SLPGraph {
       if (!node->isRoot)
         continue;
       llvm::dbgs() << "  "
-                   << (isa<memref::LoadOp>(node->ops[0]) ? "LOAD" : "STORE")
-                   << " group with " << node->ops.size() << " operations:\n";
+                   << (isa<memref::LoadOp>(node->ops.front()) ? "LOAD"
+                                                              : "STORE")
+                   << " group with " << node->size() << " operations:\n";
       for (auto *op : node->ops) {
         llvm::dbgs() << "    " << *op << "\n";
       }
       llvm::dbgs() << "    Users: ";
       for (auto *user : node->users) {
-        llvm::dbgs() << "\n      Group with " << user->ops.size()
-                     << " operations:";
+        llvm::dbgs() << "\n      Group with " << user->size() << " operations:";
         for (auto *op : user->ops) {
           llvm::dbgs() << "\n        " << *op;
         }
@@ -488,13 +494,13 @@ class SLPGraph {
     for (const auto &node : nodes) {
       if (node->isRoot)
         continue;
-      llvm::dbgs() << "  Group with " << node->ops.size() << " operations:\n";
+      llvm::dbgs() << "  Group with " << node->size() << " operations:\n";
       for (auto *op : node->ops) {
         llvm::dbgs() << "    " << *op << "\n";
       }
       llvm::dbgs() << "    Operands: ";
       for (auto *operand : node->operands) {
-        llvm::dbgs() << "\n      Group with " << operand->ops.size()
+        llvm::dbgs() << "\n      Group with " << operand->size()
                      << " operations:";
         for (auto *op : operand->ops) {
           llvm::dbgs() << "\n        " << *op;
@@ -502,8 +508,7 @@ class SLPGraph {
       }
       llvm::dbgs() << "\n    Users: ";
       for (auto *user : node->users) {
-        llvm::dbgs() << "\n      Group with " << user->ops.size()
-                     << " operations:";
+        llvm::dbgs() << "\n      Group with " << user->size() << " operations:";
         for (auto *op : user->ops) {
           llvm::dbgs() << "\n        " << *op;
         }
@@ -518,6 +523,28 @@ class SLPGraph {
   llvm::SmallDenseMap<Operation *, SLPGraphNode *> opToNode;
 };
 
+/// This pass implements the greedy SLP vectorizer. It detects consecutive
+/// operations that can be put together into vector operations. The pass works
+/// bi-directionally, starting from loads or stores, in search of scalars to
+/// combine.
+///
+/// Pass is split into multiple steps:
+/// 1. Collect memory operation groups within the same block.
+/// A group is either multiple loads uninterrupted by ops that may write memory,
+/// or multiple stores uninterrupted by ops that may read memory.
+///
+/// 2. Extract contiguous groups from memory operation groups, based on the
+/// ops' base memrefs, load/store element types, and indices.
+///
+/// 3. Build SLP graph from contiguous groups. This is done by going both
+/// top-down and bottom-up through uses/operands respectively, starting from
+/// contiguous memory operation groups.
+///
+/// 4. Vectorize SLP graph. This is done by topologically sorting the graph and
+/// vectorizing each node in sorted order.
+///
+/// Vectorization is done by cloning the operations and mapping the operands and
+/// results.
 struct GreedySLPVectorizerPass
     : public mlir::vector::impl::GreedySLPVectorizerBase<
           GreedySLPVectorizerPass> {
@@ -532,6 +559,10 @@ static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
       ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
 }
 
+/// SLP vectorizer is bi-directional, so when we go top-down we can have
+/// multiple users with the same immediate op type; this class tries to
+/// compute a fingerprint for such ops based on the entire ops graph to
+/// maximize further scalar ops merging.
 struct OperationsFingerprint {
   OperationsFingerprint(const SLPGraph &graph) : graph(graph) {}
 
@@ -606,7 +637,8 @@ static bool isEquivalent(Operation *op1, Operation *op2) {
   return true;
 }
 
-/// Build the SLP graph starting from memory operation groups
+/// Build the SLP graph starting from memory operation groups and going both
+/// top-down and bottom-up through uses/operands respectively.
 static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
   if (rootGroups.empty())
     return SLPGraph();
@@ -623,7 +655,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
     worklist.push_back(node);
 
     LLVM_DEBUG({
-      llvm::dbgs() << "Created root group node with " << node->ops.size()
+      llvm::dbgs() << "Created root group node with " << node->size()
                    << " operations of type "
                    << (group.isLoadGroup() ? "Load" : "Store") << "\n";
     });
@@ -631,6 +663,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
 
   OperationsFingerprint fingerprints(graph);
 
+  // Process node uses, going top-down.
   auto processUse = [&](SLPGraphNode *node, OpOperand &use) {
     Operation *user = use.getOwner();
     auto *existingNode = graph.getNodeForOp(user);
@@ -680,6 +713,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
     worklist.push_back(newNode);
   };
 
+  // Process node operands, going bottom-up.
   auto processOperands = [&](SLPGraphNode *node, Value operand, int64_t index) {
     Operation *srcOp = operand.getDefiningOp();
     if (!srcOp)
@@ -720,7 +754,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
 
   while (!worklist.empty()) {
     SLPGraphNode *node = worklist.pop_back_val();
-    LLVM_DEBUG(llvm::dbgs() << "Processing node with " << node->ops.size()
+    LLVM_DEBUG(llvm::dbgs() << "Processing node with " << node->size()
                             << " operations, first op: "
                             << node->ops.front()->getName() << "\n");
 

>From 5613035ee349cd4d5e157dbb82eed380ccad3920 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 22:08:35 +0200
Subject: [PATCH 25/52] check arith.add indices

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 13 ++++-
 mlir/test/Dialect/Vector/slp-vectorize.mlir   | 54 +++++++++++++++++++
 2 files changed, 66 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 24059ec355b30..aa2f3108712f1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -145,7 +145,18 @@ static bool isAdjacentIndices(Value idx1, Value idx2) {
       return *c1 + 1 == *c2;
   }
 
-  // TODO: Check arith.add, affine.apply, etc
+  if (auto addOp2 = idx2.getDefiningOp<arith::AddIOp>()) {
+    if (addOp2.getLhs() == idx1 && getConstantIntValue(addOp2.getRhs()) == 1)
+      return true;
+
+    if (auto addOp1 = idx1.getDefiningOp<arith::AddIOp>()) {
+      if (addOp1.getLhs() == addOp2.getLhs() &&
+          isAdjacentIndices(addOp1.getRhs(), addOp2.getRhs()))
+        return true;
+    }
+  }
+
+  // TODO: affine.apply, etc
   return false;
 }
 
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 2e9298d11ed05..edb722472995d 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -75,6 +75,60 @@ func.func @read_write_interleaved(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
 }
 
 
+// CHECK-LABEL: func @read_write_add_index
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>, %[[ARG2:.*]]: index)
+func.func @read_write_add_index(%arg0: memref<8xi32>, %arg1: memref<8xi32>, %arg2: index) {
+  // CHECK:     %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG2]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     vector.store %[[RES]], %[[ARG0]][%[[ARG2]]] : memref<8xi32>, vector<4xi32>
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %ind1 = arith.addi %arg2, %c1 : index
+  %ind2 = arith.addi %arg2, %c2 : index
+  %ind3 = arith.addi %arg2, %c3 : index
+
+  %0 = memref.load %arg0[%arg2] : memref<8xi32>
+  %1 = memref.load %arg0[%ind1] : memref<8xi32>
+  %2 = memref.load %arg0[%ind2] : memref<8xi32>
+  %3 = memref.load %arg0[%ind3] : memref<8xi32>
+
+  memref.store %0, %arg0[%arg2] : memref<8xi32>
+  memref.store %1, %arg0[%ind1] : memref<8xi32>
+  memref.store %2, %arg0[%ind2] : memref<8xi32>
+  memref.store %3, %arg0[%ind3] : memref<8xi32>
+
+  return
+}
+
+
+// CHECK-LABEL: func @read_write_add_index_interleaved
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>, %[[ARG2:.*]]: index)
+func.func @read_write_add_index_interleaved(%arg0: memref<8xi32>, %arg1: memref<8xi32>, %arg2: index) {
+  // CHECK:     %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG2]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     vector.store %[[RES]], %[[ARG0]][%[[ARG2]]] : memref<8xi32>, vector<4xi32>
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %ind1 = arith.addi %arg2, %c1 : index
+  %ind2 = arith.addi %arg2, %c2 : index
+  %ind3 = arith.addi %arg2, %c3 : index
+
+  %0 = memref.load %arg0[%arg2] : memref<8xi32>
+  %1 = memref.load %arg0[%ind1] : memref<8xi32>
+  %3 = memref.load %arg0[%ind3] : memref<8xi32>
+  %2 = memref.load %arg0[%ind2] : memref<8xi32>
+
+  memref.store %3, %arg0[%ind3] : memref<8xi32>
+  memref.store %0, %arg0[%arg2] : memref<8xi32>
+  memref.store %1, %arg0[%ind1] : memref<8xi32>
+  memref.store %2, %arg0[%ind2] : memref<8xi32>
+
+  return
+}
+
+
 // CHECK-LABEL: func @read_read_add
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
 func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i32, i32, i32){

>From a43bc4144f656a575cc703331ef34e6feac422a4 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 23:41:34 +0200
Subject: [PATCH 26/52] fix vecor sizes

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 24 ++++++----
 mlir/test/Dialect/Vector/slp-vectorize.mlir   | 47 +++++++++++++++++++
 2 files changed, 61 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index aa2f3108712f1..dfd4747f615ee 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -12,14 +12,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/Passes.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/SHA1.h"
 
@@ -371,15 +368,24 @@ class SLPGraph {
       }
     });
 
-    auto isGoodNode = [&](SLPGraphNode *node) {
+    auto isBadNode = [&](SLPGraphNode *node) {
       return node->users.empty() && node->operands.empty();
     };
 
-    IRMapping mapping;
+    // Update vec sizes if inputs are smaller.
     for (auto *node : sortedNodes) {
-      if (isGoodNode(node))
-        continue;
+      size_t size = node->size();
+      for (auto *operand : node->operands)
+        size = std::min(size, operand->size());
+
+      node->ops.resize(size);
+    }
+
+    // Remove nodes that are not good (have users or operands)
+    llvm::erase_if(sortedNodes, isBadNode);
 
+    IRMapping mapping;
+    for (auto *node : sortedNodes) {
       int64_t numElements = node->size();
       Operation *op = node->ops.front();
       rewriter.setInsertionPoint(op);
@@ -462,14 +468,12 @@ class SLPGraph {
     }
 
     for (auto *node : llvm::reverse(sortedNodes)) {
-      if (isGoodNode(node))
-        continue;
-
       for (Operation *op : node->ops) {
         rewriter.eraseOp(op);
       }
     }
 
+    LLVM_DEBUG(llvm::dbgs() << "Vectorization completed successfully\n");
     return success();
   }
 
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index edb722472995d..7ad077d8fd78c 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -356,3 +356,50 @@ func.func @read_read_add_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>,
 
   return
 }
+
+
+func.func private @use(i32)
+
+// CHECK-LABEL: func @read_read_add_write_interleaved_use
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_interleaved_use(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[C3:.*]] = arith.constant 3 : index
+  // CHECK:     %[[V0:.*]] = memref.load %[[ARG0]][%[[C3]]] : memref<8xi32>
+  // CHECK:     %[[V1:.*]] = memref.load %[[ARG1]][%[[C3]]] : memref<8xi32>
+  // CHECK:     call @use(%[[V0]]) : (i32) -> ()
+  // CHECK:     %[[V2:.*]] = arith.addi %[[V0]], %[[V1]] : i32
+  // CHECK:     %[[V3:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<3xi32>
+  // CHECK:     %[[V4:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<3xi32>
+  // CHECK:     %[[V5:.*]] = arith.addi %[[V3]], %[[V4]] : vector<3xi32>
+  // CHECK:     vector.store %[[V5]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<3xi32>
+  // CHECK:     memref.store %[[V2]], %[[ARG0]][%[[C3]]] : memref<8xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+  %7 = memref.load %arg1[%c3] : memref<8xi32>
+  call @use(%3) : (i32) -> ()
+  %11 = arith.addi %3, %7 : i32
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %8 = arith.addi %0, %4 : i32
+
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %6 = memref.load %arg1[%c2] : memref<8xi32>
+  %10 = arith.addi %2, %6 : i32
+
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %5 = memref.load %arg1[%c1] : memref<8xi32>
+  %9 = arith.addi %1, %5 : i32
+
+  memref.store %8, %arg0[%c0] : memref<8xi32>
+  memref.store %11, %arg0[%c3] : memref<8xi32>
+  memref.store %10, %arg0[%c2] : memref<8xi32>
+  memref.store %9, %arg0[%c1] : memref<8xi32>
+
+  return
+}

>From dfd44a4b78b389d9703f49fbcfc835e308baeac6 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 23:48:08 +0200
Subject: [PATCH 27/52] fix op insertion point

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 12 ++++-
 mlir/test/Dialect/Vector/slp-vectorize.mlir   | 44 +++++++++++++++++++
 2 files changed, 55 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index dfd4747f615ee..ab5bfd94de49c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -255,6 +255,16 @@ struct SLPGraphNode {
       : ops(operations.begin(), operations.end()) {}
 
   size_t size() const { return ops.size(); }
+
+  Operation *getEarliestOp() const {
+    assert(!ops.empty() && "empty node");
+    Operation *ret = ops.front();
+    for (Operation *op : ArrayRef(ops).drop_front()) {
+      if (op->isBeforeInBlock(ret))
+        ret = op;
+    }
+    return ret;
+  }
 };
 
 /// A graph of vectorizable operations
@@ -388,7 +398,7 @@ class SLPGraph {
     for (auto *node : sortedNodes) {
       int64_t numElements = node->size();
       Operation *op = node->ops.front();
-      rewriter.setInsertionPoint(op);
+      rewriter.setInsertionPoint(node->getEarliestOp());
       Location loc = op->getLoc();
 
       auto handleNonVectorInputs = [&](ValueRange operands) {
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 7ad077d8fd78c..9d06a1faa07b2 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -403,3 +403,47 @@ func.func @read_read_add_write_interleaved_use(%arg0: memref<8xi32>, %arg1: memr
 
   return
 }
+
+
+// CHECK-LABEL: func @read_read_add_write_interleaved_use_add
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_interleaved_use_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[V1:.*]] = vector.extract %[[V0]][3] : i32 from vector<4xi32>
+  // CHECK:     %[[V2:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[V3:.*]] = vector.extract %[[V2]][3] : i32 from vector<4xi32>
+  // CHECK:     %[[V4:.*]] = arith.subi %[[V1]], %[[V3]] : i32
+  // CHECK:     %[[V5:.*]] = arith.addi %[[V0]], %[[V2]] : vector<4xi32>
+  // CHECK:     vector.store %[[V5]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     call @use(%[[V4]]) : (i32) -> ()
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+  %7 = memref.load %arg1[%c3] : memref<8xi32>
+  %12 = arith.subi %3, %7 : i32
+  %11 = arith.addi %3, %7 : i32
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %8 = arith.addi %0, %4 : i32
+
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %6 = memref.load %arg1[%c2] : memref<8xi32>
+  %10 = arith.addi %2, %6 : i32
+
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %5 = memref.load %arg1[%c1] : memref<8xi32>
+  %9 = arith.addi %1, %5 : i32
+
+  memref.store %8, %arg0[%c0] : memref<8xi32>
+  memref.store %11, %arg0[%c3] : memref<8xi32>
+  memref.store %10, %arg0[%c2] : memref<8xi32>
+  memref.store %9, %arg0[%c1] : memref<8xi32>
+
+  call @use(%12) : (i32) -> ()
+  return
+}

>From 82da589254683c697e461f12c0e486bbcb6e9de1 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 19 May 2025 11:00:11 +0200
Subject: [PATCH 28/52] check same block

---
 mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index ab5bfd94de49c..ec1c41dbd7b69 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -659,6 +659,9 @@ static bool isEquivalent(Operation *op1, Operation *op2) {
   if (op1->getRawDictionaryAttrs() != op2->getRawDictionaryAttrs())
     return false;
 
+  if (op1->getBlock() != op2->getBlock())
+    return false;
+
   return true;
 }
 

>From 5c339976bb985449473a30f7783b1dbc4741efd6 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 12:13:50 +0200
Subject: [PATCH 29/52] cleanup and comments

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 118 +++++++++++++-----
 1 file changed, 90 insertions(+), 28 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index ec1c41dbd7b69..c6e20961725c7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -48,7 +48,7 @@ struct MemoryOpGroup {
   size_t size() const { return ops.size(); }
 };
 
-static bool isReadOp(Operation *op) {
+static bool maybeReadOp(Operation *op) {
   auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
   if (!effectInterface)
     return true;
@@ -56,7 +56,7 @@ static bool isReadOp(Operation *op) {
   return effectInterface.hasEffect<MemoryEffects::Read>();
 }
 
-static bool isWriteOp(Operation *op) {
+static bool maybeWriteOp(Operation *op) {
   auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
   if (!effectInterface)
     return true;
@@ -72,10 +72,11 @@ static SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block) {
   MemoryOpGroup *currentGroup = nullptr;
 
   for (Operation &op : block) {
+    // Check if current group is interrupted by a read or write op.
     if (currentGroup) {
-      if (currentGroup->isLoadGroup() && isWriteOp(&op)) {
+      if (currentGroup->isLoadGroup() && maybeWriteOp(&op)) {
         currentGroup = nullptr;
-      } else if (currentGroup->isStoreGroup() && isReadOp(&op)) {
+      } else if (currentGroup->isStoreGroup() && maybeReadOp(&op)) {
         currentGroup = nullptr;
       }
     }
@@ -83,7 +84,7 @@ static SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block) {
     if (!isa<memref::LoadOp, memref::StoreOp>(op))
       continue;
 
-    bool isLoad = isReadOp(&op);
+    bool isLoad = maybeReadOp(&op);
     MemoryOpGroup::Type type =
         isLoad ? MemoryOpGroup::Type::Load : MemoryOpGroup::Type::Store;
 
@@ -99,6 +100,7 @@ static SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block) {
 }
 
 static Value getBase(Operation *op) {
+  assert(op && "null op");
   if (auto loadOp = dyn_cast<memref::LoadOp>(op))
     return loadOp.getMemRef();
   if (auto storeOp = dyn_cast<memref::StoreOp>(op))
@@ -120,6 +122,7 @@ static bool isContiguousLastDim(Value val) {
 }
 
 static ValueRange getIndices(Operation *op) {
+  assert(op && "null op");
   if (auto loadOp = dyn_cast<memref::LoadOp>(op))
     return loadOp.getIndices();
   if (auto storeOp = dyn_cast<memref::StoreOp>(op))
@@ -128,6 +131,7 @@ static ValueRange getIndices(Operation *op) {
 }
 
 static Type getElementType(Operation *op) {
+  assert(op && "null op");
   if (auto loadOp = dyn_cast<memref::LoadOp>(op))
     return loadOp.getResult().getType();
   if (auto storeOp = dyn_cast<memref::StoreOp>(op))
@@ -135,7 +139,7 @@ static Type getElementType(Operation *op) {
   return {};
 }
 
-/// Check if two indices are consecutive, i.e fastest index differs by 1.
+/// Check if two indices are consecutive, i.e index1 + 1 == index2.
 static bool isAdjacentIndices(Value idx1, Value idx2) {
   if (auto c1 = getConstantIntValue(idx1)) {
     if (auto c2 = getConstantIntValue(idx2))
@@ -153,7 +157,7 @@ static bool isAdjacentIndices(Value idx1, Value idx2) {
     }
   }
 
-  // TODO: affine.apply, etc
+  // TODO: Handle affine.apply, etc
   return false;
 }
 
@@ -173,6 +177,9 @@ static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2) {
 /// This is done by checking if the base memrefs are the same, the last
 /// dimension is contiguous, and the element types and indices are compatible
 static bool isAdjacentOps(Operation *op1, Operation *op2) {
+  assert(op1 && "null op1");
+  assert(op2 && "null op2");
+
   Value base1 = getBase(op1);
   Value base2 = getBase(op2);
   if (base1 != base2)
@@ -181,8 +188,10 @@ static bool isAdjacentOps(Operation *op1, Operation *op2) {
   if (!isContiguousLastDim(base1))
     return false;
 
-  return getElementType(op1) == getElementType(op2) &&
-         isAdjacentIndices(getIndices(op1), getIndices(op2));
+  if (getElementType(op1) != getElementType(op2))
+    return false;
+
+  return isAdjacentIndices(getIndices(op1), getIndices(op2));
 }
 
 // Extract contiguous groups from a MemoryOpGroup
@@ -229,6 +238,7 @@ extractContiguousGroups(const MemoryOpGroup &group) {
     } while (foundMore);
 
     if (currentOps.size() <= 1) {
+      // Do not vectorize if there is only one op.
       result.pop_back();
       continue;
     }
@@ -256,9 +266,16 @@ struct SLPGraphNode {
 
   size_t size() const { return ops.size(); }
 
-  Operation *getEarliestOp() const {
+  Operation *op() const {
+    assert(!ops.empty() && "empty ops");
+    return ops.front();
+  }
+
+  Operation *getInsertionPoint() const {
+    // Find the topologically first op, which is not necessarily the first in
+    // `ops`, as `ops` are sorted by their position in the vector.
     assert(!ops.empty() && "empty node");
-    Operation *ret = ops.front();
+    Operation *ret = op();
     for (Operation *op : ArrayRef(ops).drop_front()) {
       if (op->isBeforeInBlock(ret))
         ret = op;
@@ -374,15 +391,20 @@ class SLPGraph {
       llvm::dbgs() << "Topologically sorted nodes:\n";
       for (auto *node : sortedNodes) {
         llvm::dbgs() << "  Node with " << node->size()
-                     << " operations: " << node->ops.front()->getName() << "\n";
+                     << " operations: " << node->op()->getName() << "\n";
       }
     });
 
     auto isBadNode = [&](SLPGraphNode *node) {
+      // Do not vectorize stray nodes which are not connected to any other
+      // nodes.
       return node->users.empty() && node->operands.empty();
     };
 
-    // Update vec sizes if inputs are smaller.
+    // Shrink a node's vector size if any of its operand nodes is smaller.
+    // This is needed to handle situations like 3->3->4 sizes in the tree.
+    // TODO: It may be possible to reconstruct the larger vec size by combining
+    // the smaller source vector with a scalar arg.
     for (auto *node : sortedNodes) {
       size_t size = node->size();
       for (auto *operand : node->operands)
@@ -391,14 +413,19 @@ class SLPGraph {
       node->ops.resize(size);
     }
 
-    // Remove nodes that are not good (have users or operands)
     llvm::erase_if(sortedNodes, isBadNode);
 
     IRMapping mapping;
     for (auto *node : sortedNodes) {
+      // `op` is the op with the smallest index in the vector, which is not
+      // necessarily a valid insertion point.
+      Operation *op = node->op();
+      Operation *ip = node->getInsertionPoint();
+      if (!ip)
+        return op->emitError("no insertion point found for node");
+
+      rewriter.setInsertionPoint(ip);
       int64_t numElements = node->size();
-      Operation *op = node->ops.front();
-      rewriter.setInsertionPoint(node->getEarliestOp());
       Location loc = op->getLoc();
 
       auto handleNonVectorInputs = [&](ValueRange operands) {
@@ -477,6 +504,10 @@ class SLPGraph {
       }
     }
 
+    LLVM_DEBUG(llvm::dbgs() << "Erasing original ops\n");
+
+    // As all nodes were cloned, we need to erase the original ops in reverse
+    // topo order so that no op is erased while it still has users.
     for (auto *node : llvm::reverse(sortedNodes)) {
       for (Operation *op : node->ops) {
         rewriter.eraseOp(op);
@@ -498,8 +529,7 @@ class SLPGraph {
       if (!node->isRoot)
         continue;
       llvm::dbgs() << "  "
-                   << (isa<memref::LoadOp>(node->ops.front()) ? "LOAD"
-                                                              : "STORE")
+                   << (isa<memref::LoadOp>(node->op()) ? "LOAD" : "STORE")
                    << " group with " << node->size() << " operations:\n";
       for (auto *op : node->ops) {
         llvm::dbgs() << "    " << *op << "\n";
@@ -588,10 +618,41 @@ static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
 /// multiple users with the same immediate op type, this class tries to compute
 /// fingerprint for such ops based on the entire ops graph to maximize further
 /// scalar ops merging.
+///
+/// Example:
+/// ```
+///  %0 = memref.load %arg0[%c0] : memref<8xi32>
+///  %1 = memref.load %arg0[%c1] : memref<8xi32>
+///  %2 = memref.load %arg0[%c2] : memref<8xi32>
+///  %3 = memref.load %arg0[%c3] : memref<8xi32>
+///
+///  %4 = memref.load %arg1[%c0] : memref<8xi32>
+///  %5 = memref.load %arg1[%c1] : memref<8xi32>
+///  %6 = memref.load %arg1[%c2] : memref<8xi32>
+///  %7 = memref.load %arg1[%c3] : memref<8xi32>
+///
+///  %8 = arith.addi %0, %4 : i32
+///  %12 = arith.addi %0, %arg2 : i32
+///
+///  %13 = arith.addi %1, %arg3 : i32
+///  %9 = arith.addi %1, %5 : i32
+///
+///  %10 = arith.addi %2, %6 : i32
+///  %14 = arith.addi %2, %arg4 : i32
+///
+///  %15 = arith.addi %3, %arg5 : i32
+///  %11 = arith.addi %3, %7 : i32
+/// ```
+/// Here each load has multiple uses, in a different order, and we want to merge
+/// them in a way that maximizes the number of merged ops.
+///
+/// To achieve this, we compute a fingerprint for each op that includes its
+/// other operands, which in this example are the other loads.
 struct OperationsFingerprint {
   OperationsFingerprint(const SLPGraph &graph) : graph(graph) {}
 
   Fingerprint getFingerprint(Operation *op) {
+    assert(op && "null op");
     auto it = fingerprints.find(op);
     if (it != fingerprints.end())
       return it->second;
@@ -653,6 +714,9 @@ struct OperationsFingerprint {
 };
 
 static bool isEquivalent(Operation *op1, Operation *op2) {
+  assert(op1 && "null op1");
+  assert(op2 && "null op2");
+
   if (op1->getName() != op2->getName())
     return false;
 
@@ -696,9 +760,8 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
     Operation *user = use.getOwner();
     auto *existingNode = graph.getNodeForOp(user);
     if (existingNode) {
-      LLVM_DEBUG(llvm::dbgs()
-                 << "  Adding edge from " << node->ops.front()->getName()
-                 << " to " << user->getName() << "\n");
+      LLVM_DEBUG(llvm::dbgs() << "  Adding edge from " << node->op()->getName()
+                              << " to " << user->getName() << "\n");
       graph.addEdge(node, existingNode);
       return;
     }
@@ -749,9 +812,8 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
 
     auto *existingNode = graph.getNodeForOp(srcOp);
     if (existingNode) {
-      LLVM_DEBUG(llvm::dbgs()
-                 << "  Adding edge from " << srcOp->getName() << " to "
-                 << node->ops.front()->getName() << "\n");
+      LLVM_DEBUG(llvm::dbgs() << "  Adding edge from " << srcOp->getName()
+                              << " to " << node->op()->getName() << "\n");
       graph.addEdge(existingNode, node);
       return;
     }
@@ -782,11 +844,11 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
 
   while (!worklist.empty()) {
     SLPGraphNode *node = worklist.pop_back_val();
-    LLVM_DEBUG(llvm::dbgs() << "Processing node with " << node->size()
-                            << " operations, first op: "
-                            << node->ops.front()->getName() << "\n");
+    LLVM_DEBUG(llvm::dbgs()
+               << "Processing node with " << node->size()
+               << " operations, first op: " << node->op()->getName() << "\n");
 
-    Operation *op = node->ops.front();
+    Operation *op = node->op();
     for (OpOperand &use : op->getUses())
       processUse(node, use);
 

>From d30060c54c75226c9bd9002b0f264bbbcd70d8b0 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 12:28:49 +0200
Subject: [PATCH 30/52] test

---
 mlir/test/Dialect/Vector/slp-vectorize.mlir | 58 ++++++++++++++++++++-
 1 file changed, 57 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 9d06a1faa07b2..f744098324243 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -131,7 +131,7 @@ func.func @read_write_add_index_interleaved(%arg0: memref<8xi32>, %arg1: memref<
 
 // CHECK-LABEL: func @read_read_add
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
-func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i32, i32, i32){
+func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i32, i32, i32) {
   // CHECK:     %[[C0:.*]] = arith.constant 0 : index
   // CHECK:     %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
   // CHECK:     %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
@@ -309,6 +309,8 @@ func.func @read_read_add_write_interleaved(%arg0: memref<8xi32>, %arg1: memref<8
 //  CHECK-SAME: , %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
 func.func @read_read_add_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>,
                                    %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) {
+  // Each load group has 2 uses (potentially in a different order);
+  // make sure both of them were vectorized.
   // CHECK:     %[[C0:.*]] = arith.constant 0 : index
   // CHECK:     %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
   // CHECK:     %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
@@ -357,6 +359,60 @@ func.func @read_read_add_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>,
   return
 }
 
+// CHECK-LABEL: func @read_read_add_add
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>
+//  CHECK-SAME: , %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
+func.func @read_read_add_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>,
+                                   %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) ->
+                                   (i32, i32, i32, i32, i32, i32, i32, i32){
+  // Each load group has 2 uses (potentially in a different order);
+  // make sure both of them were vectorized.
+  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[ADD1:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+  // CHECK:     %[[R0:.*]] = vector.extract %[[ADD1]][0] : i32 from vector<4xi32>
+  // CHECK:     %[[R1:.*]] = vector.extract %[[ADD1]][1] : i32 from vector<4xi32>
+  // CHECK:     %[[R2:.*]] = vector.extract %[[ADD1]][2] : i32 from vector<4xi32>
+  // CHECK:     %[[R3:.*]] = vector.extract %[[ADD1]][3] : i32 from vector<4xi32>
+  // CHECK:     %[[C:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]] : vector<4xi32>
+  // CHECK:     %[[ADD2:.*]] = arith.addi %[[A]], %[[C]] : vector<4xi32>
+  // CHECK:     %[[R4:.*]] = vector.extract %[[ADD2]][0] : i32 from vector<4xi32>
+  // CHECK:     %[[R5:.*]] = vector.extract %[[ADD2]][1] : i32 from vector<4xi32>
+  // CHECK:     %[[R6:.*]] = vector.extract %[[ADD2]][2] : i32 from vector<4xi32>
+  // CHECK:     %[[R7:.*]] = vector.extract %[[ADD2]][3] : i32 from vector<4xi32>
+  // CHECK:     return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]] : i32, i32, i32, i32, i32, i32, i32, i32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+  %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %5 = memref.load %arg1[%c1] : memref<8xi32>
+  %6 = memref.load %arg1[%c2] : memref<8xi32>
+  %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+  %8 = arith.addi %0, %4 : i32
+  %12 = arith.addi %0, %arg2 : i32
+
+  %13 = arith.addi %1, %arg3 : i32
+  %9 = arith.addi %1, %5 : i32
+
+  %10 = arith.addi %2, %6 : i32
+  %14 = arith.addi %2, %arg4 : i32
+
+  %15 = arith.addi %3, %arg5 : i32
+  %11 = arith.addi %3, %7 : i32
+
+  return %8, %9, %10, %11, %12, %13, %14, %15 : i32, i32, i32, i32, i32, i32, i32, i32
+}
+
+
 
 func.func private @use(i32)
 

>From a4b9529fbfb460c71465fff58d19710c5f23f39b Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 12:38:22 +0200
Subject: [PATCH 31/52] test

---
 mlir/test/Dialect/Vector/slp-vectorize.mlir | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index f744098324243..293d004879fe5 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -503,3 +503,17 @@ func.func @read_read_add_write_interleaved_use_add(%arg0: memref<8xi32>, %arg1:
   call @use(%12) : (i32) -> ()
   return
 }
+
+
+// CHECK-LABEL: func @negative_single_op
+func.func @negative_single_op(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK-NOT: vector
+  %c0 = arith.constant 0 : index
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %8 = arith.addi %0, %4 : i32
+  memref.store %8, %arg0[%c0] : memref<8xi32>
+
+  return
+}

>From 78a5ed97251236873fe53b5613bc8872e65caece Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 13:45:19 +0200
Subject: [PATCH 32/52] Run until fixed point

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 101 +++++++++++-------
 mlir/test/Dialect/Vector/slp-vectorize.mlir   |  49 +++++++++
 2 files changed, 109 insertions(+), 41 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index c6e20961725c7..6059f8937e000 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -372,10 +372,11 @@ class SLPGraph {
     return result;
   }
 
-  /// Vectorize the operations in the graph
-  LogicalResult vectorize(IRRewriter &rewriter) {
+  /// Vectorize the operations in the graph.
+  /// Returns number of nodes vectorized or failure if failed.
+  FailureOr<size_t> vectorize(IRRewriter &rewriter) {
     if (nodes.empty())
-      return success();
+      return 0;
 
     LLVM_DEBUG(llvm::dbgs()
                << "Vectorizing SLP graph with " << nodes.size() << " nodes\n");
@@ -515,7 +516,7 @@ class SLPGraph {
     }
 
     LLVM_DEBUG(llvm::dbgs() << "Vectorization completed successfully\n");
-    return success();
+    return sortedNodes.size();
   }
 
   /// Print the graph structure
@@ -720,7 +721,7 @@ static bool isEquivalent(Operation *op1, Operation *op2) {
   if (op1->getName() != op2->getName())
     return false;
 
-  if (op1->getRawDictionaryAttrs() != op2->getRawDictionaryAttrs())
+  if (op1->getAttrs() != op2->getAttrs())
     return false;
 
   if (op1->getBlock() != op2->getBlock())
@@ -859,48 +860,66 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
   return graph;
 }
 
+/// Try to vectorize ops in a block.
+/// Returns number of nodes vectorized or error flag if failed.
+static FailureOr<size_t> tryToVectorizeInBlock(Block &block) {
+  LLVM_DEBUG(llvm::dbgs() << "Processing block in operation: "
+                          << block.getParentOp()->getName() << "\n");
+
+  // Collect memory operation groups
+  SmallVector<MemoryOpGroup> groups = collectMemoryOpGroups(block);
+
+  // Process each group to find contiguous sequences
+  SmallVector<MemoryOpGroup> rootGroups;
+  for (const auto &group : groups) {
+    SmallVector<MemoryOpGroup> contiguousGroups =
+        extractContiguousGroups(group);
+    LLVM_DEBUG({
+      llvm::dbgs() << "Found " << contiguousGroups.size()
+                   << " contiguous groups in "
+                   << (group.isLoadGroup() ? "load" : "store") << " group\n";
+      for (const auto &contigGroup : contiguousGroups) {
+        llvm::dbgs() << "  Contiguous group with " << contigGroup.size()
+                     << " operations\n";
+      }
+    });
+    rootGroups.append(contiguousGroups.begin(), contiguousGroups.end());
+  }
+
+  // Build the SLP graph from root groups
+  SLPGraph graph = buildSLPGraph(rootGroups);
+  LLVM_DEBUG(graph.print());
+
+  // Vectorize the graph
+  IRRewriter rewriter(block.getParentOp()->getContext());
+  FailureOr<size_t> numNodesVectorized = graph.vectorize(rewriter);
+  if (failed(numNodesVectorized))
+    LLVM_DEBUG(llvm::dbgs() << "Failed to vectorize graph\n");
+
+  return numNodesVectorized;
+}
+
 void GreedySLPVectorizerPass::runOnOperation() {
   Operation *op = getOperation();
 
-  // Walk all blocks recursively
-  op->walk([&](Block *block) -> WalkResult {
-    LLVM_DEBUG(llvm::dbgs() << "Processing block in operation: "
-                            << block->getParentOp()->getName() << "\n");
-
-    // Collect memory operation groups
-    SmallVector<MemoryOpGroup> groups = collectMemoryOpGroups(*block);
-
-    // Process each group to find contiguous sequences
-    SmallVector<MemoryOpGroup> rootGroups;
-    for (const auto &group : groups) {
-      SmallVector<MemoryOpGroup> contiguousGroups =
-          extractContiguousGroups(group);
-      LLVM_DEBUG({
-        llvm::dbgs() << "Found " << contiguousGroups.size()
-                     << " contiguous groups in "
-                     << (group.isLoadGroup() ? "load" : "store") << " group\n";
-        for (const auto &contigGroup : contiguousGroups) {
-          llvm::dbgs() << "  Contiguous group with " << contigGroup.size()
-                       << " operations\n";
-        }
-      });
-      rootGroups.append(contiguousGroups.begin(), contiguousGroups.end());
-    }
+  // Run until fixed point is reached.
+  bool changed;
+  do {
+    changed = false;
+    // Walk all blocks recursively
+    if (op->walk([&](Block *block) -> WalkResult {
+            FailureOr<size_t> numNodesVectorized =
+                tryToVectorizeInBlock(*block);
+            if (failed(numNodesVectorized))
+              return WalkResult::interrupt();
 
-    // Build the SLP graph from root groups
-    SLPGraph graph = buildSLPGraph(rootGroups);
-    LLVM_DEBUG(graph.print());
+            changed = changed || *numNodesVectorized > 0;
 
-    // Vectorize the graph
-    IRRewriter rewriter(&getContext());
-    if (failed(graph.vectorize(rewriter))) {
-      LLVM_DEBUG(llvm::dbgs() << "Failed to vectorize graph\n");
-      signalPassFailure();
-      return WalkResult::interrupt();
-    }
+            return WalkResult::advance();
+          }).wasInterrupted())
+      return signalPassFailure();
 
-    return WalkResult::advance();
-  });
+  } while (changed);
 }
 
 } // namespace
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 293d004879fe5..517b2318f773d 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -266,6 +266,55 @@ func.func @read_read_add_write_size_mismatch(%arg0: memref<8xi32>, %arg1: memref
 }
 
 
+// CHECK-LABEL: func @read_read_add_write_attrs_mismatch
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_attrs_mismatch(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
+  // CHECK:     %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[V1:.*]] = vector.extract %[[V0]][2] : i32 from vector<4xi32>
+  // CHECK:     %[[V2:.*]] = vector.extract %[[V0]][3] : i32 from vector<4xi32>
+  // CHECK:     %[[V3:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[V4:.*]] = vector.extract %[[V3]][2] : i32 from vector<4xi32>
+  // CHECK:     %[[V5:.*]] = vector.extract %[[V3]][3] : i32 from vector<4xi32>
+  // CHECK:     %[[V6:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+  // CHECK:     %[[V7:.*]] = vector.extract_strided_slice %[[V3]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+  // CHECK:     %[[V8:.*]] = arith.addi %[[V6]], %[[V7]] overflow<nsw> : vector<2xi32>
+  // CHECK:     %[[V9:.*]] = vector.from_elements %[[V1]], %[[V2]] : vector<2xi32>
+  // CHECK:     %[[V10:.*]] = vector.from_elements %[[V4]], %[[V5]] : vector<2xi32>
+  // CHECK:     %[[V11:.*]] = arith.addi %[[V9]], %[[V10]] overflow<nuw> : vector<2xi32>
+  // CHECK:     vector.store %[[V8]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+  // CHECK:     vector.store %[[V11]], %[[ARG0]][%[[C2]]] : memref<8xi32>, vector<2xi32>
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+  %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %5 = memref.load %arg1[%c1] : memref<8xi32>
+  %6 = memref.load %arg1[%c2] : memref<8xi32>
+  %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+  %8 = arith.addi %0, %4 overflow<nsw> : i32
+  %9 = arith.addi %1, %5 overflow<nsw> : i32
+  %10 = arith.addi %2, %6 overflow<nuw> : i32
+  %11 = arith.addi %3, %7 overflow<nuw> : i32
+
+  memref.store %8, %arg0[%c0] : memref<8xi32>
+  memref.store %9, %arg0[%c1] : memref<8xi32>
+  memref.store %10, %arg0[%c2] : memref<8xi32>
+  memref.store %11, %arg0[%c3] : memref<8xi32>
+
+  return
+}
+
+
 // CHECK-LABEL: func @read_read_add_write_interleaved
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
 func.func @read_read_add_write_interleaved(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {

>From b875a1867746acd5bb0d941171c2908be7d3d52e Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 14:16:40 +0200
Subject: [PATCH 33/52] run DCE between iterations

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 24 +++++++++++--------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 6059f8937e000..871611a891351 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Vector/Transforms/Passes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/SHA1.h"
 
@@ -906,19 +907,22 @@ void GreedySLPVectorizerPass::runOnOperation() {
   bool changed;
   do {
     changed = false;
-    // Walk all blocks recursively
-    if (op->walk([&](Block *block) -> WalkResult {
-            FailureOr<size_t> numNodesVectorized =
-                tryToVectorizeInBlock(*block);
-            if (failed(numNodesVectorized))
-              return WalkResult::interrupt();
-
-            changed = changed || *numNodesVectorized > 0;
+    auto visitor = [&](Block *block) -> WalkResult {
+      FailureOr<size_t> numNodesVectorized = tryToVectorizeInBlock(*block);
+      if (failed(numNodesVectorized))
+        return WalkResult::interrupt();
 
-            return WalkResult::advance();
-          }).wasInterrupted())
+      changed = changed || *numNodesVectorized > 0;
+      return WalkResult::advance();
+    };
+    // Walk all blocks recursively
+    if (op->walk(visitor).wasInterrupted())
       return signalPassFailure();
 
+    // Run empty `applyPatternsGreedily` for simple DCE and folding.
+    if (changed)
+      (void)applyPatternsGreedily(
+          op, {}, GreedyRewriteConfig().enableFolding().enableConstantCSE());
   } while (changed);
 }
 

>From 6746318fb79cebe27611e1fd6b566c9562f2e353 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 15:37:57 +0200
Subject: [PATCH 34/52] comment

---
 mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 871611a891351..ce6b088e0e07f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -715,9 +715,13 @@ struct OperationsFingerprint {
   DenseMap<Operation *, Fingerprint> fingerprints;
 };
 
+/// Check if two ops are equivalent for the purposes of SLP vectorization, i.e.
+/// they can be merged into single vector op.
 static bool isEquivalent(Operation *op1, Operation *op2) {
   assert(op1 && "null op1");
   assert(op2 && "null op2");
+  if (op1 == op2)
+    return true;
 
   if (op1->getName() != op2->getName())
     return false;

>From e650bbe4797b92058b7172d64b3a47d0153efc1f Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 15:45:52 +0200
Subject: [PATCH 35/52] test

---
 mlir/test/Dialect/Vector/slp-vectorize.mlir | 75 ++++++++++++++++++++-
 1 file changed, 72 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 517b2318f773d..cbcd553d90f0a 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -1,6 +1,20 @@
 // RUN: mlir-opt %s --greedy-slp-vectorizer | FileCheck %s
 
 
+// CHECK-LABEL: func @negative_single_op
+func.func @negative_single_op(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK-NOT: vector
+  %c0 = arith.constant 0 : index
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %8 = arith.addi %0, %4 : i32
+  memref.store %8, %arg0[%c0] : memref<8xi32>
+
+  return
+}
+
+
 // CHECK-LABEL: func @read_write
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
 func.func @read_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
@@ -554,15 +568,70 @@ func.func @read_read_add_write_interleaved_use_add(%arg0: memref<8xi32>, %arg1:
 }
 
 
-// CHECK-LABEL: func @negative_single_op
-func.func @negative_single_op(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
-  // CHECK-NOT: vector
+// CHECK-LABEL: func @negative_different_blocks
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @negative_different_blocks(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+  // CHECK:     %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[V1:.*]] = vector.extract %[[V0]][2] : i32 from vector<4xi32>
+  // CHECK:     %[[V2:.*]] = vector.extract %[[V0]][3] : i32 from vector<4xi32>
+  // CHECK:     %[[V3:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[V4:.*]] = vector.extract %[[V3]][2] : i32 from vector<4xi32>
+  // CHECK:     %[[V5:.*]] = vector.extract %[[V3]][3] : i32 from vector<4xi32>
+  // CHECK:     cf.br ^bb1
+  // CHECK:   ^bb1:
+  // CHECK:     %[[V6:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+  // CHECK:     %[[V7:.*]] = vector.extract_strided_slice %[[V3]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+  // CHECK:     %[[V8:.*]] = arith.addi %[[V6]], %[[V7]] : vector<2xi32>
+  // CHECK:     %[[V9:.*]] = vector.extract %[[V8]][0] : i32 from vector<2xi32>
+  // CHECK:     %[[V10:.*]] = vector.extract %[[V8]][1] : i32 from vector<2xi32>
+  // CHECK:     cf.br ^bb2
+  // CHECK:   ^bb2:
+  // TODO: we need to properly handle vector.extract args to vectorizre that
+  // CHECK:     %[[V11:.*]] = arith.addi %[[V1]], %[[V4]] : i32
+  // CHECK:     %[[V12:.*]] = arith.addi %[[V2]], %[[V5]] : i32
+  // CHECK:     cf.br ^bb3
+  // CHECK:   ^bb3:
+  // CHECK:     memref.store %[[V9]], %[[ARG0]][%[[C0]]] : memref<8xi32>
+  // CHECK:     memref.store %[[V10]], %[[ARG0]][%[[C1]]] : memref<8xi32>
+  // CHECK:     memref.store %[[V11]], %[[ARG0]][%[[C2]]] : memref<8xi32>
+  // CHECK:     memref.store %[[V12]], %[[ARG0]][%[[C3]]] : memref<8xi32>
+
   %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
 
   %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+
   %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %5 = memref.load %arg1[%c1] : memref<8xi32>
+  %6 = memref.load %arg1[%c2] : memref<8xi32>
+  %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+  cf.br ^bb0
+
+^bb0:
   %8 = arith.addi %0, %4 : i32
+  %9 = arith.addi %1, %5 : i32
+  cf.br ^bb1
+
+^bb1:
+  %10 = arith.addi %2, %6 : i32
+  %11 = arith.addi %3, %7 : i32
+  cf.br ^bb2
+
+^bb2:
   memref.store %8, %arg0[%c0] : memref<8xi32>
+  memref.store %9, %arg0[%c1] : memref<8xi32>
+  memref.store %10, %arg0[%c2] : memref<8xi32>
+  memref.store %11, %arg0[%c3] : memref<8xi32>
 
   return
 }

>From 08f362e04c81cf8e64ab6379d5bc8595e379a465 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 15:47:12 +0200
Subject: [PATCH 36/52] cleanup

---
 mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index ce6b088e0e07f..d73f35cd2599a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -924,9 +924,10 @@ void GreedySLPVectorizerPass::runOnOperation() {
       return signalPassFailure();
 
     // Run empty `applyPatternsGreedily` for simple DCE and folding.
-    if (changed)
-      (void)applyPatternsGreedily(
-          op, {}, GreedyRewriteConfig().enableFolding().enableConstantCSE());
+    if (changed) {
+      auto config = GreedyRewriteConfig().enableFolding().enableConstantCSE();
+      (void)applyPatternsGreedily(op, {}, config);
+    }
   } while (changed);
 }
 

>From 4b44d61b6be1af15b1d6d97975a94f6b38600f79 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 17:50:52 +0200
Subject: [PATCH 37/52] process extract ops

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 79 ++++++++++++++++---
 mlir/test/Dialect/Vector/slp-vectorize.mlir   | 74 +++++++----------
 2 files changed, 99 insertions(+), 54 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index d73f35cd2599a..bdfaa72a36914 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -251,7 +251,19 @@ extractContiguousGroups(const MemoryOpGroup &group) {
 }
 
 static bool isVectorizable(Operation *op) {
-  return OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1;
+  if (!OpTrait::hasElementwiseMappableTraits(op))
+    return false;
+
+  if (op->getNumResults() != 1)
+    return false;
+
+  for (auto type :
+       llvm::concat<Type>(op->getResultTypes(), op->getOperandTypes())) {
+    if (!type.isIntOrIndexOrFloat())
+      return false;
+  }
+
+  return true;
 }
 
 /// A node in the SLP graph representing a group of vectorizable operations
@@ -419,6 +431,12 @@ class SLPGraph {
 
     IRMapping mapping;
     for (auto *node : sortedNodes) {
+      LLVM_DEBUG({
+        llvm::dbgs() << "Processing node with " << node->size()
+                     << " operations\n";
+        llvm::dbgs() << "  First op: " << *node->op() << "\n";
+      });
+
       // `op` is the node with the smallest index in vector and not the
       // nessesarily the good insertion point.
       Operation *op = node->op();
@@ -500,6 +518,9 @@ class SLPGraph {
 
         mapping.map(op->getResults(), newOp->getResults());
         handleNonVectorOutputs(newOp->getResult(0));
+      } else if (auto extract = dyn_cast<vector::ExtractOp>(op)) {
+        Value val = handleVecSizeMismatch(extract.getVector());
+        mapping.map(extract.getResult(), val);
       } else {
         op->emitError("unsupported operation");
         return failure();
@@ -735,6 +756,14 @@ static bool isEquivalent(Operation *op1, Operation *op2) {
   return true;
 }
 
+/// Get static position of the extract op, if it is 1D and static.
+static std::optional<int64_t> getExtractIndex(vector::ExtractOp extractOp) {
+  if (extractOp.getNumIndices() != 1 || extractOp.hasDynamicPosition())
+    return std::nullopt;
+
+  return extractOp.getStaticPosition().front();
+}
+
 /// Build the SLP graph starting from memory operation groups and going both
 /// top-down and bottom-up through uses/operands respectively.
 static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
@@ -824,17 +853,47 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
       return;
     }
 
-    if (!isVectorizable(srcOp))
-      return;
-
     SmallVector<Operation *> currentOps;
-    currentOps.emplace_back(srcOp);
-    for (Operation *op : ArrayRef(node->ops).drop_front()) {
-      Operation *otherOp = op->getOperand(index).getDefiningOp();
-      if (!otherOp || !isEquivalent(otherOp, srcOp))
-        break;
+    if (auto extractOp = dyn_cast<vector::ExtractOp>(srcOp)) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "  Processing vector.extract op with index "
+                 << getExtractIndex(extractOp).value_or(-1) << "\n");
+      currentOps.push_back(extractOp);
+
+      std::optional<int64_t> extractIndex = getExtractIndex(extractOp);
+      if (!extractIndex)
+        return;
+
+      Value vector = extractOp.getVector();
+      int64_t currentIndex = *extractIndex;
+      for (Operation *op : ArrayRef(node->ops).drop_front()) {
+        auto otherOp = op->getOperand(index).getDefiningOp<vector::ExtractOp>();
+        if (!otherOp || otherOp.getVector() != vector)
+          break;
+
+        std::optional<int64_t> otherExtractIndex = getExtractIndex(otherOp);
+        if (!otherExtractIndex || *otherExtractIndex != (currentIndex + 1))
+          break;
+
+        currentOps.push_back(otherOp);
+        ++currentIndex;
+      }
+    } else if (isVectorizable(srcOp)) {
+      LLVM_DEBUG(llvm::dbgs() << "  Processing vectorizable op "
+                              << srcOp->getName() << "\n");
+
+      currentOps.emplace_back(srcOp);
+      for (Operation *op : ArrayRef(node->ops).drop_front()) {
+        Operation *otherOp = op->getOperand(index).getDefiningOp();
+        if (!otherOp || !isEquivalent(otherOp, srcOp))
+          break;
 
-      currentOps.push_back(otherOp);
+        currentOps.push_back(otherOp);
+      }
+    } else {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "  Unsupported op " << srcOp->getName() << "\n");
+      return;
     }
 
     if (currentOps.size() == 1)
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index cbcd553d90f0a..b27a72c6a8fe7 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -283,22 +283,18 @@ func.func @read_read_add_write_size_mismatch(%arg0: memref<8xi32>, %arg1: memref
 // CHECK-LABEL: func @read_read_add_write_attrs_mismatch
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
 func.func @read_read_add_write_attrs_mismatch(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
-  // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
-  // CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
-  // CHECK:     %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
-  // CHECK:     %[[V1:.*]] = vector.extract %[[V0]][2] : i32 from vector<4xi32>
-  // CHECK:     %[[V2:.*]] = vector.extract %[[V0]][3] : i32 from vector<4xi32>
-  // CHECK:     %[[V3:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
-  // CHECK:     %[[V4:.*]] = vector.extract %[[V3]][2] : i32 from vector<4xi32>
-  // CHECK:     %[[V5:.*]] = vector.extract %[[V3]][3] : i32 from vector<4xi32>
-  // CHECK:     %[[V6:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
-  // CHECK:     %[[V7:.*]] = vector.extract_strided_slice %[[V3]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
-  // CHECK:     %[[V8:.*]] = arith.addi %[[V6]], %[[V7]] overflow<nsw> : vector<2xi32>
-  // CHECK:     %[[V9:.*]] = vector.from_elements %[[V1]], %[[V2]] : vector<2xi32>
-  // CHECK:     %[[V10:.*]] = vector.from_elements %[[V4]], %[[V5]] : vector<2xi32>
-  // CHECK:     %[[V11:.*]] = arith.addi %[[V9]], %[[V10]] overflow<nuw> : vector<2xi32>
-  // CHECK:     vector.store %[[V8]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
-  // CHECK:     vector.store %[[V11]], %[[ARG0]][%[[C2]]] : memref<8xi32>, vector<2xi32>
+    // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+    // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+    // CHECK:     %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+    // CHECK:     %[[V1:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+    // CHECK:     %[[V2:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+    // CHECK:     %[[V3:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+    // CHECK:     %[[V4:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+    // CHECK:     %[[V5:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+    // CHECK:     %[[V6:.*]] = arith.addi %[[V4]], %[[V5]] overflow<nsw> : vector<2xi32>
+    // CHECK:     %[[V7:.*]] = arith.addi %[[V1]], %[[V3]] overflow<nuw> : vector<2xi32>
+    // CHECK:     vector.store %[[V6]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+    // CHECK:     vector.store %[[V7]], %[[ARG0]][%[[C2]]] : memref<8xi32>, vector<2xi32>
 
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -571,34 +567,24 @@ func.func @read_read_add_write_interleaved_use_add(%arg0: memref<8xi32>, %arg1:
 // CHECK-LABEL: func @negative_different_blocks
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
 func.func @negative_different_blocks(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
-  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-  // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
-  // CHECK:     %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
-  // CHECK:     %[[V1:.*]] = vector.extract %[[V0]][2] : i32 from vector<4xi32>
-  // CHECK:     %[[V2:.*]] = vector.extract %[[V0]][3] : i32 from vector<4xi32>
-  // CHECK:     %[[V3:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
-  // CHECK:     %[[V4:.*]] = vector.extract %[[V3]][2] : i32 from vector<4xi32>
-  // CHECK:     %[[V5:.*]] = vector.extract %[[V3]][3] : i32 from vector<4xi32>
-  // CHECK:     cf.br ^bb1
-  // CHECK:   ^bb1:
-  // CHECK:     %[[V6:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
-  // CHECK:     %[[V7:.*]] = vector.extract_strided_slice %[[V3]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
-  // CHECK:     %[[V8:.*]] = arith.addi %[[V6]], %[[V7]] : vector<2xi32>
-  // CHECK:     %[[V9:.*]] = vector.extract %[[V8]][0] : i32 from vector<2xi32>
-  // CHECK:     %[[V10:.*]] = vector.extract %[[V8]][1] : i32 from vector<2xi32>
-  // CHECK:     cf.br ^bb2
-  // CHECK:   ^bb2:
-  // TODO: we need to properly handle vector.extract args to vectorizre that
-  // CHECK:     %[[V11:.*]] = arith.addi %[[V1]], %[[V4]] : i32
-  // CHECK:     %[[V12:.*]] = arith.addi %[[V2]], %[[V5]] : i32
-  // CHECK:     cf.br ^bb3
-  // CHECK:   ^bb3:
-  // CHECK:     memref.store %[[V9]], %[[ARG0]][%[[C0]]] : memref<8xi32>
-  // CHECK:     memref.store %[[V10]], %[[ARG0]][%[[C1]]] : memref<8xi32>
-  // CHECK:     memref.store %[[V11]], %[[ARG0]][%[[C2]]] : memref<8xi32>
-  // CHECK:     memref.store %[[V12]], %[[ARG0]][%[[C3]]] : memref<8xi32>
+    // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+    // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+    // CHECK:     %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+    // CHECK:     %[[V1:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+    // CHECK:     %[[V2:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+    // CHECK:     %[[V3:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+    // CHECK:     cf.br ^bb1
+    // CHECK:   ^bb1:
+    // CHECK:     %[[V4:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+    // CHECK:     %[[V5:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+    // CHECK:     %[[V6:.*]] = arith.addi %[[V4]], %[[V5]] : vector<2xi32>
+    // CHECK:     cf.br ^bb2
+    // CHECK:   ^bb2:
+    // CHECK:     %[[V7:.*]] = arith.addi %[[V1]], %[[V3]] : vector<2xi32>
+    // CHECK:     cf.br ^bb3
+    // CHECK:   ^bb3:
+    // CHECK:     vector.store %[[V6]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+    // CHECK:     vector.store %[[V7]], %[[ARG0]][%[[C2]]] : memref<8xi32>, vector<2xi32>
 
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index

>From 293b27a2b1c32ad4e36f4923b7d720c92fec2a20 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 19:53:47 +0200
Subject: [PATCH 38/52] handle vec size and dominance

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 409 +++++++++++-------
 mlir/test/Dialect/Vector/slp-vectorize.mlir   |  22 +-
 2 files changed, 267 insertions(+), 164 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index bdfaa72a36914..e7c550b64e71f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -11,11 +11,13 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Analysis/DataLayoutAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/Passes.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"
@@ -266,6 +268,13 @@ static bool isVectorizable(Operation *op) {
   return true;
 }
 
+/// Get the next operation in the block, assuming `op` is not a terminator.
+static Operation *nextOp(Operation *op) {
+  assert(op && "null op");
+  auto it = op->getIterator();
+  return &*std::next(it);
+}
+
 /// A node in the SLP graph representing a group of vectorizable operations
 struct SLPGraphNode {
   SmallVector<Operation *> ops;
@@ -293,6 +302,31 @@ struct SLPGraphNode {
       if (op->isBeforeInBlock(ret))
         ret = op;
     }
+
+    for (Operation *op : ops) {
+      for (Value opOperand : op->getOperands()) {
+        Operation *defOp = opOperand.getDefiningOp();
+        if (!defOp || defOp->getBlock() != ret->getBlock())
+          continue;
+
+        Operation *next = nextOp(defOp);
+        if (ret->isBeforeInBlock(next))
+          ret = next;
+      }
+    }
+
+    // Try to adjust insertion point to satisfy dominance relations with
+    // operands.
+    for (SLPGraphNode *operand : operands) {
+      Operation *ip = operand->getInsertionPoint();
+      if (!ip)
+        return nullptr;
+
+      Operation *next = nextOp(ip);
+      if (next->getBlock() == ret->getBlock() && ret->isBeforeInBlock(next))
+        ret = next;
+    }
+
     return ret;
   }
 };
@@ -387,159 +421,9 @@ class SLPGraph {
 
   /// Vectorize the operations in the graph.
   /// Returns number of nodes vectorized or failure if failed.
-  FailureOr<size_t> vectorize(IRRewriter &rewriter) {
-    if (nodes.empty())
-      return 0;
-
-    LLVM_DEBUG(llvm::dbgs()
-               << "Vectorizing SLP graph with " << nodes.size() << " nodes\n");
-
-    // Get topologically sorted nodes
-    SmallVector<SLPGraphNode *> sortedNodes = topologicalSort();
-    if (sortedNodes.empty()) {
-      LLVM_DEBUG(llvm::dbgs() << "Failed to topologically sort nodes\n");
-      return failure();
-    }
-
-    LLVM_DEBUG({
-      llvm::dbgs() << "Topologically sorted nodes:\n";
-      for (auto *node : sortedNodes) {
-        llvm::dbgs() << "  Node with " << node->size()
-                     << " operations: " << node->op()->getName() << "\n";
-      }
-    });
-
-    auto isBadNode = [&](SLPGraphNode *node) {
-      // Do not vectorize stray nodes which are not connected to any other
-      // nodes.
-      return node->users.empty() && node->operands.empty();
-    };
-
-    // Update node vec sizes if its inputs vec sizes are smaller.
-    // This is nedeed to handle situations when we have 3->3->4 sizes in tree.
-    // TODO: It maybe possible to reconstruct the larger vec size combining src
-    // smaller vector and scalar arg.
-    for (auto *node : sortedNodes) {
-      size_t size = node->size();
-      for (auto *operand : node->operands)
-        size = std::min(size, operand->size());
-
-      node->ops.resize(size);
-    }
-
-    llvm::erase_if(sortedNodes, isBadNode);
-
-    IRMapping mapping;
-    for (auto *node : sortedNodes) {
-      LLVM_DEBUG({
-        llvm::dbgs() << "Processing node with " << node->size()
-                     << " operations\n";
-        llvm::dbgs() << "  First op: " << *node->op() << "\n";
-      });
-
-      // `op` is the node with the smallest index in vector and not the
-      // nessesarily the good insertion point.
-      Operation *op = node->op();
-      Operation *ip = node->getInsertionPoint();
-      if (!ip)
-        return op->emitError("no insertion point found for node");
-
-      rewriter.setInsertionPoint(ip);
-      int64_t numElements = node->size();
-      Location loc = op->getLoc();
-
-      auto handleNonVectorInputs = [&](ValueRange operands) {
-        for (auto [i, operand] : llvm::enumerate(operands)) {
-          if (getNodeForOp(operand.getDefiningOp()))
-            continue;
-
-          SmallVector<Value> args;
-          for (Operation *defOp : node->ops)
-            args.push_back(defOp->getOperand(i));
-
-          auto vecType = VectorType::get(numElements, operand.getType());
-          Value vector =
-              rewriter.create<vector::FromElementsOp>(loc, vecType, args);
-          mapping.map(operand, vector);
-        }
-      };
-
-      auto handleNonVectorOutputs = [&](Value newResult) {
-        for (auto [i, result] : llvm::enumerate(node->ops)) {
-          for (OpOperand &use : result->getUses()) {
-            Operation *useOwner = use.getOwner();
-            if (getNodeForOp(useOwner))
-              continue;
-
-            Value elem = rewriter.create<vector::ExtractOp>(loc, newResult, i);
-            use.set(elem);
-          }
-        }
-      };
-
-      auto handleVecSizeMismatch = [&](Value arg) -> Value {
-        auto srcType = cast<VectorType>(arg.getType());
-        assert(srcType.getRank() == 1);
-        if (srcType.getDimSize(0) == numElements)
-          return arg;
-
-        return rewriter.create<vector::ExtractStridedSliceOp>(loc, arg, 0,
-                                                              numElements, 1);
-      };
-
-      if (auto load = dyn_cast<memref::LoadOp>(op)) {
-        auto vecType =
-            VectorType::get(numElements, load.getMemRefType().getElementType());
-        Value result = rewriter.create<vector::LoadOp>(
-            loc, vecType, load.getMemRef(), load.getIndices());
-        mapping.map(load.getResult(), result);
-        handleNonVectorOutputs(result);
-      } else if (auto store = dyn_cast<memref::StoreOp>(op)) {
-        handleNonVectorInputs(store.getValueToStore());
-        Value val = mapping.lookupOrDefault(store.getValueToStore());
-        val = handleVecSizeMismatch(val);
-        rewriter.create<vector::StoreOp>(loc, val, store.getMemRef(),
-                                         store.getIndices());
-      } else if (isVectorizable(op)) {
-        handleNonVectorInputs(op->getOperands());
-        Operation *newOp = rewriter.clone(*op, mapping);
-        auto resVectorType =
-            VectorType::get(numElements, op->getResultTypes().front());
-
-        {
-          OpBuilder::InsertionGuard guard(rewriter);
-          rewriter.setInsertionPoint(newOp);
-          for (OpOperand &operand : newOp->getOpOperands()) {
-            Value newOperand = handleVecSizeMismatch(operand.get());
-            operand.set(newOperand);
-          }
-        }
-        newOp->getResult(0).setType(resVectorType);
-
-        mapping.map(op->getResults(), newOp->getResults());
-        handleNonVectorOutputs(newOp->getResult(0));
-      } else if (auto extract = dyn_cast<vector::ExtractOp>(op)) {
-        Value val = handleVecSizeMismatch(extract.getVector());
-        mapping.map(extract.getResult(), val);
-      } else {
-        op->emitError("unsupported operation");
-        return failure();
-      }
-    }
-
-    LLVM_DEBUG(llvm::dbgs() << "Erasing original ops\n");
-
-    // As all nodes were cloned, we need to erase the original ops in reverse
-    // topo order to avoid invalidation users.
-    for (auto *node : llvm::reverse(sortedNodes)) {
-      for (Operation *op : node->ops) {
-        rewriter.eraseOp(op);
-      }
-    }
-
-    LLVM_DEBUG(llvm::dbgs() << "Vectorization completed successfully\n");
-    return sortedNodes.size();
-  }
+  FailureOr<size_t>
+  vectorize(IRRewriter &rewriter,
+            llvm::function_ref<bool(Type, size_t)> isValidVecType);
 
   /// Print the graph structure
   [[maybe_unused]] void print() const {
@@ -736,6 +620,31 @@ struct OperationsFingerprint {
   DenseMap<Operation *, Fingerprint> fingerprints;
 };
 
+/// Check if op input/output types can be vectorized.
+static bool
+checkOpVecType(SLPGraphNode *node,
+               llvm::function_ref<bool(Type, size_t)> isValidVecType) {
+  Operation *op = node->op();
+  size_t size = node->size();
+  if (Type elementType = getElementType(op))
+    return isValidVecType(elementType, size);
+
+  if (isVectorizable(op)) {
+    for (auto type :
+         llvm::concat<Type>(op->getResultTypes(), op->getOperandTypes())) {
+      if (!isValidVecType(type, size))
+        return false;
+    }
+    return true;
+  }
+
+  if (auto extract = dyn_cast<vector::ExtractOp>(op))
+    return isValidVecType(extract.getResult().getType(), size);
+
+  LLVM_DEBUG(llvm::dbgs() << "Unsupported op " << op->getName() << "\n");
+  return false;
+}
+
 /// Check if two ops are equivalent for the purposes of SLP vectorization, i.e.
 /// they can be merged into single vector op.
 static bool isEquivalent(Operation *op1, Operation *op2) {
@@ -924,9 +833,176 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
   return graph;
 }
 
+FailureOr<size_t>
+SLPGraph::vectorize(IRRewriter &rewriter,
+                    llvm::function_ref<bool(Type, size_t)> isValidVecType) {
+  if (nodes.empty())
+    return 0;
+
+  LLVM_DEBUG(llvm::dbgs() << "Vectorizing SLP graph with " << nodes.size()
+                          << " nodes\n");
+
+  // Get topologically sorted nodes
+  SmallVector<SLPGraphNode *> sortedNodes = topologicalSort();
+  if (sortedNodes.empty()) {
+    LLVM_DEBUG(llvm::dbgs() << "Failed to topologically sort nodes\n");
+    return failure();
+  }
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "Topologically sorted nodes:\n";
+    for (auto *node : sortedNodes) {
+      llvm::dbgs() << "  Node with " << node->size()
+                   << " operations: " << node->op()->getName() << "\n";
+    }
+  });
+
+  auto isBadNode = [&](SLPGraphNode *node) {
+    // Do not vectorize single-op nodes, nor stray nodes which are not
+    // connected to any other nodes.
+    return (node->users.empty() && node->operands.empty()) || node->size() <= 1;
+  };
+
+  // Clamp each node's size to the smallest size among its operand nodes, then
+  // drop trailing ops until `checkOpVecType` succeeds or one op is left.
+  // Needed for trees with e.g. 3->3->4 sizes. TODO: it may be possible to build
+  // the larger vector from the smaller source vector plus scalar args.
+  for (auto *node : sortedNodes) {
+    size_t size = node->size();
+    for (auto *operand : node->operands)
+      size = std::min(size, operand->size());
+
+    node->ops.resize(size);
+
+    while (node->size() > 1) {
+      if (checkOpVecType(node, isValidVecType))
+        break;
+
+      node->ops.pop_back();
+    }
+  }
+
+  llvm::erase_if(sortedNodes, isBadNode);
+
+  IRMapping mapping;
+  for (auto *node : sortedNodes) {
+    LLVM_DEBUG({
+      llvm::dbgs() << "Processing node with " << node->size()
+                   << " operations\n";
+      llvm::dbgs() << "  First op: " << *node->op() << "\n";
+    });
+
+    // `op` is the node's first op by vector position, which is not
+    // necessarily the right place to insert the new vector op.
+    Operation *op = node->op();
+    Operation *ip = node->getInsertionPoint();
+    if (!ip)
+      return op->emitError("no insertion point found for node");
+
+    LLVM_DEBUG(llvm::dbgs() << "  Insertion point: " << *ip << "\n");
+
+    rewriter.setInsertionPoint(ip);
+    int64_t numElements = node->size();
+    Location loc = op->getLoc();
+
+    auto handleNonVectorInputs = [&](ValueRange operands) {
+      for (auto [i, operand] : llvm::enumerate(operands)) {
+        if (getNodeForOp(operand.getDefiningOp()))
+          continue;
+
+        SmallVector<Value> args;
+        for (Operation *defOp : node->ops)
+          args.push_back(defOp->getOperand(i));
+
+        auto vecType = VectorType::get(numElements, operand.getType());
+        Value vector =
+            rewriter.create<vector::FromElementsOp>(loc, vecType, args);
+        mapping.map(operand, vector);
+      }
+    };
+
+    auto handleNonVectorOutputs = [&](Value newResult) {
+      // Rewire users outside the graph to per-lane `vector.extract`s.
+      for (auto [i, scalarOp] : llvm::enumerate(node->ops)) {
+        // `use.set()` unlinks the use, so advance the iterator beforehand.
+        for (OpOperand &use : llvm::make_early_inc_range(scalarOp->getUses())) {
+          if (getNodeForOp(use.getOwner()))
+            continue;
+          Value elem = rewriter.create<vector::ExtractOp>(loc, newResult, i);
+          use.set(elem);
+        }
+      }
+    };
+
+    auto handleVecSizeMismatch = [&](Value arg) -> Value {
+      auto srcType = cast<VectorType>(arg.getType());
+      assert(srcType.getRank() == 1);
+      if (srcType.getDimSize(0) == numElements)
+        return arg;
+
+      return rewriter.create<vector::ExtractStridedSliceOp>(loc, arg, 0,
+                                                            numElements, 1);
+    };
+
+    if (auto load = dyn_cast<memref::LoadOp>(op)) {
+      auto vecType =
+          VectorType::get(numElements, load.getMemRefType().getElementType());
+      Value result = rewriter.create<vector::LoadOp>(
+          loc, vecType, load.getMemRef(), load.getIndices());
+      mapping.map(load.getResult(), result);
+      handleNonVectorOutputs(result);
+    } else if (auto store = dyn_cast<memref::StoreOp>(op)) {
+      handleNonVectorInputs(store.getValueToStore());
+      Value val = mapping.lookupOrDefault(store.getValueToStore());
+      val = handleVecSizeMismatch(val);
+      rewriter.create<vector::StoreOp>(loc, val, store.getMemRef(),
+                                       store.getIndices());
+    } else if (isVectorizable(op)) {
+      handleNonVectorInputs(op->getOperands());
+      Operation *newOp = rewriter.clone(*op, mapping);
+      auto resVectorType =
+          VectorType::get(numElements, op->getResultTypes().front());
+
+      {
+        OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPoint(newOp);
+        for (OpOperand &operand : newOp->getOpOperands()) {
+          Value newOperand = handleVecSizeMismatch(operand.get());
+          operand.set(newOperand);
+        }
+      }
+      newOp->getResult(0).setType(resVectorType);
+
+      mapping.map(op->getResults(), newOp->getResults());
+      handleNonVectorOutputs(newOp->getResult(0));
+    } else if (auto extract = dyn_cast<vector::ExtractOp>(op)) {
+      Value val = handleVecSizeMismatch(extract.getVector());
+      mapping.map(extract.getResult(), val);
+    } else {
+      op->emitError("unsupported operation");
+      return failure();
+    }
+  }
+
+  LLVM_DEBUG(llvm::dbgs() << "Erasing original ops\n");
+
+  // As all nodes were cloned, we need to erase the original ops in reverse
+  // topo order to avoid invalidation users.
+  for (auto *node : llvm::reverse(sortedNodes)) {
+    for (Operation *op : node->ops) {
+      rewriter.eraseOp(op);
+    }
+  }
+
+  LLVM_DEBUG(llvm::dbgs() << "Vectorization completed successfully\n");
+  return sortedNodes.size();
+}
+
 /// Try to vectorize ops in a block.
 /// Returns number of nodes vectorized or error flag if failed.
-static FailureOr<size_t> tryToVectorizeInBlock(Block &block) {
+static FailureOr<size_t>
+tryToVectorizeInBlock(Block &block,
+                      llvm::function_ref<bool(Type, size_t)> isValidVecType) {
   LLVM_DEBUG(llvm::dbgs() << "Processing block in operation: "
                           << block.getParentOp()->getName() << "\n");
 
@@ -956,22 +1032,42 @@ static FailureOr<size_t> tryToVectorizeInBlock(Block &block) {
 
   // Vectorize the graph
   IRRewriter rewriter(block.getParentOp()->getContext());
-  FailureOr<size_t> numNodesVectorized = graph.vectorize(rewriter);
+  FailureOr<size_t> numNodesVectorized =
+      graph.vectorize(rewriter, isValidVecType);
   if (failed(numNodesVectorized))
     LLVM_DEBUG(llvm::dbgs() << "Failed to vectorize graph\n");
 
   return numNodesVectorized;
 }
 
+static bool isPow2(size_t size) {
+  assert(size > 0);
+  return (size & (size - 1)) == 0;
+}
+
 void GreedySLPVectorizerPass::runOnOperation() {
   Operation *op = getOperation();
 
+  const DataLayout *dataLayout = nullptr;
+  auto isValidVecType = [&](Type type, size_t count) {
+    if (!isPow2(count))
+      return false;
+
+    if (!dataLayout)
+      dataLayout = &getAnalysis<DataLayoutAnalysis>().getAtOrAbove(op);
+
+    auto sizeInBits = dataLayout->getTypeSizeInBits(type);
+
+    return sizeInBits * count <= 256;
+  };
+
   // Run until fixed point is reached.
   bool changed;
   do {
     changed = false;
     auto visitor = [&](Block *block) -> WalkResult {
-      FailureOr<size_t> numNodesVectorized = tryToVectorizeInBlock(*block);
+      FailureOr<size_t> numNodesVectorized =
+          tryToVectorizeInBlock(*block, isValidVecType);
       if (failed(numNodesVectorized))
         return WalkResult::interrupt();
 
@@ -987,6 +1083,7 @@ void GreedySLPVectorizerPass::runOnOperation() {
       auto config = GreedyRewriteConfig().enableFolding().enableConstantCSE();
       (void)applyPatternsGreedily(op, {}, config);
     }
+    op->dump();
   } while (changed);
 }
 
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index b27a72c6a8fe7..c363fe9491ee3 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -479,16 +479,22 @@ func.func private @use(i32)
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
 func.func @read_read_add_write_interleaved_use(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
   // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[C2:.*]] = arith.constant 2 : index
   // CHECK:     %[[C3:.*]] = arith.constant 3 : index
-  // CHECK:     %[[V0:.*]] = memref.load %[[ARG0]][%[[C3]]] : memref<8xi32>
-  // CHECK:     %[[V1:.*]] = memref.load %[[ARG1]][%[[C3]]] : memref<8xi32>
+  // CHECK:     %[[V0:.*]] = memref.load %[[ARG0]][%[[C3]]] : memref<8xi32>
+  // CHECK:     %[[V1:.*]] = memref.load %[[ARG1]][%[[C3]]] : memref<8xi32>
   // CHECK:     call @use(%[[V0]]) : (i32) -> ()
-  // CHECK:     %[[V2:.*]] = arith.addi %[[V0]], %[[V1]] : i32
-  // CHECK:     %[[V3:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<3xi32>
-  // CHECK:     %[[V4:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<3xi32>
-  // CHECK:     %[[V5:.*]] = arith.addi %[[V3]], %[[V4]] : vector<3xi32>
-  // CHECK:     vector.store %[[V5]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<3xi32>
-  // CHECK:     memref.store %[[V2]], %[[ARG0]][%[[C3]]] : memref<8xi32>
+  // CHECK:     %[[V2:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+  // CHECK:     %[[V3:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+  // CHECK:     %[[V4:.*]] = memref.load %[[ARG0]][%[[C2]]] : memref<8xi32>
+  // CHECK:     %[[V5:.*]] = memref.load %[[ARG1]][%[[C2]]] : memref<8xi32>
+  // CHECK:     %[[V6:.*]] = vector.from_elements %[[V4]], %[[V0]] : vector<2xi32>
+  // CHECK:     %[[V7:.*]] = vector.from_elements %[[V5]], %[[V1]] : vector<2xi32>
+  // CHECK:     %[[V8:.*]] = arith.addi %[[V6]], %[[V7]] : vector<2xi32>
+  // CHECK:     %[[V9:.*]] = arith.addi %[[V2]], %[[V3]] : vector<2xi32>
+  // CHECK:     vector.store %[[V9]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+  // CHECK:     vector.store %[[V8]], %[[ARG0]][%[[C2]]] : memref<8xi32>, vector<2xi32>
+
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index

>From 8db555609147d308d14ac317fb9b20ba4ad11828 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 19:59:51 +0200
Subject: [PATCH 39/52] cache insertion point

---
 mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index e7c550b64e71f..ced04e406edbc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -280,6 +280,7 @@ struct SLPGraphNode {
   SmallVector<Operation *> ops;
   SmallVector<SLPGraphNode *> users;
   SmallVector<SLPGraphNode *> operands;
+  Operation *insertionPoint = nullptr;
   bool isRoot = false;
 
   SLPGraphNode() = default;
@@ -293,10 +294,13 @@ struct SLPGraphNode {
     return ops.front();
   }
 
-  Operation *getInsertionPoint() const {
+  Operation *getInsertionPoint() {
+    assert(!ops.empty() && "empty node");
+    if (insertionPoint)
+      return insertionPoint;
+
     // Find the toplogically first node, which is not nessesary the first in the
     // `ops` as `ops` are sorted by their position in vector.
-    assert(!ops.empty() && "empty node");
     Operation *ret = op();
     for (Operation *op : ArrayRef(ops).drop_front()) {
       if (op->isBeforeInBlock(ret))
@@ -327,6 +331,7 @@ struct SLPGraphNode {
         ret = next;
     }
 
+    insertionPoint = ret;
     return ret;
   }
 };

>From e4c1589f545dd8f2a0b85aa536b95c2464d16dfb Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 20:33:56 +0200
Subject: [PATCH 40/52] test

---
 .../Vector/Transforms/SLPVectorizer.cpp       |  2 +-
 mlir/test/Dialect/Vector/slp-vectorize.mlir   | 63 +++++++++++++++++++
 2 files changed, 64 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index ced04e406edbc..a73da4a93dec3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -995,6 +995,7 @@ SLPGraph::vectorize(IRRewriter &rewriter,
   // topo order to avoid invalidation users.
   for (auto *node : llvm::reverse(sortedNodes)) {
     for (Operation *op : node->ops) {
+      LLVM_DEBUG(llvm::dbgs() << "Erasing op: " << *op << "\n");
       rewriter.eraseOp(op);
     }
   }
@@ -1088,7 +1089,6 @@ void GreedySLPVectorizerPass::runOnOperation() {
       auto config = GreedyRewriteConfig().enableFolding().enableConstantCSE();
       (void)applyPatternsGreedily(op, {}, config);
     }
-    op->dump();
   } while (changed);
 }
 
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index c363fe9491ee3..0a40037b015b1 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -245,6 +245,69 @@ func.func @read_read_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
 }
 
 
+// CHECK-LABEL: func @read_read_add_write_seven
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xindex>, %[[ARG1:.*]]: memref<8xindex>)
+func.func @read_read_add_write_seven(%arg0: memref<8xindex>, %arg1: memref<8xindex>) {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+  // CHECK:     %[[A0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xindex>, vector<4xindex>
+  // CHECK:     %[[A1:.*]] = vector.load %[[ARG0]][%[[C4]]] : memref<8xindex>, vector<2xindex>
+  // CHECK:     %[[A2:.*]] = memref.load %[[ARG0]][%[[C6]]] : memref<8xindex>
+  // CHECK:     %[[B0:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xindex>, vector<4xindex>
+  // CHECK:     %[[B1:.*]] = vector.load %[[ARG1]][%[[C4]]] : memref<8xindex>, vector<2xindex>
+  // CHECK:     %[[B2:.*]] = memref.load %[[ARG1]][%[[C6]]] : memref<8xindex>
+  // CHECK:     %[[RES0:.*]] = arith.addi %[[A0]], %[[B0]] : vector<4xindex>
+  // CHECK:     %[[RES1:.*]] = arith.addi %[[A1]], %[[B1]] : vector<2xindex>
+  // CHECK:     %[[RES2:.*]] = arith.addi %[[A2]], %[[B2]] : index
+  // CHECK:     vector.store %[[RES0]], %[[ARG0]][%[[C0]]] : memref<8xindex>, vector<4xindex>
+  // CHECK:     vector.store %[[RES1]], %[[ARG0]][%[[C4]]] : memref<8xindex>, vector<2xindex>
+  // CHECK:     memref.store %[[RES2]], %[[ARG0]][%[[C6]]] : memref<8xindex>
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %c5 = arith.constant 5 : index
+  %c6 = arith.constant 6 : index
+
+  %0 = memref.load %arg0[%c0] : memref<8xindex>
+  %1 = memref.load %arg0[%c1] : memref<8xindex>
+  %2 = memref.load %arg0[%c2] : memref<8xindex>
+  %3 = memref.load %arg0[%c3] : memref<8xindex>
+  %4 = memref.load %arg0[%c4] : memref<8xindex>
+  %5 = memref.load %arg0[%c5] : memref<8xindex>
+  %6 = memref.load %arg0[%c6] : memref<8xindex>
+
+  %7 = memref.load %arg1[%c0] : memref<8xindex>
+  %8 = memref.load %arg1[%c1] : memref<8xindex>
+  %9 = memref.load %arg1[%c2] : memref<8xindex>
+  %10 = memref.load %arg1[%c3] : memref<8xindex>
+  %11 = memref.load %arg1[%c4] : memref<8xindex>
+  %12 = memref.load %arg1[%c5] : memref<8xindex>
+  %13 = memref.load %arg1[%c6] : memref<8xindex>
+
+  %14 = arith.addi %0, %7 : index
+  %15 = arith.addi %1, %8 : index
+  %16 = arith.addi %2, %9 : index
+  %17 = arith.addi %3, %10 : index
+  %18 = arith.addi %4, %11 : index
+  %19 = arith.addi %5, %12 : index
+  %20 = arith.addi %6, %13 : index
+
+  memref.store %14, %arg0[%c0] : memref<8xindex>
+  memref.store %15, %arg0[%c1] : memref<8xindex>
+  memref.store %16, %arg0[%c2] : memref<8xindex>
+  memref.store %17, %arg0[%c3] : memref<8xindex>
+  memref.store %18, %arg0[%c4] : memref<8xindex>
+  memref.store %19, %arg0[%c5] : memref<8xindex>
+  memref.store %20, %arg0[%c6] : memref<8xindex>
+
+  return
+}
+
+
 // CHECK-LABEL: func @read_read_add_write_size_mismatch
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
 func.func @read_read_add_write_size_mismatch(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {

>From fb190bf6e136c85dee1b748640fc405bb5b34f6b Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 20:47:44 +0200
Subject: [PATCH 41/52] pass option

---
 mlir/include/mlir/Dialect/Vector/Transforms/Passes.td | 10 ++++++----
 mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp  |  4 +++-
 mlir/test/Dialect/Vector/slp-vectorize.mlir           |  2 +-
 3 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index d5c31c9f78409..970e488d3494d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -44,12 +44,14 @@ def GreedySLPVectorizer : Pass<"greedy-slp-vectorizer"> {
 
     This is greedy vectorizer, it doesn't have any cost model (yet) and it tries
     to create vector ops if we have at least 2 potential ops.
-
-    It doesn't check if target actually supports resulted vectors either, user
-    will need a follow up pass which will split large and/or unaliggned vectors
-    into sizes actually supported by the target.
   }];
   let dependentDialects = ["mlir::vector::VectorDialect"];
+
+  let options = [
+    Option<"maxVectorBitwidth", "max-vector-bitwidth", "unsigned",
+      /*default=*/"std::numeric_limits<unsigned>::max()",
+      "Maximum supported vector bitwidth">,
+  ];
 }
 
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index a73da4a93dec3..0e650359d3339 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -515,6 +515,8 @@ class SLPGraph {
 struct GreedySLPVectorizerPass
     : public mlir::vector::impl::GreedySLPVectorizerBase<
           GreedySLPVectorizerPass> {
+  using GreedySLPVectorizerBase::GreedySLPVectorizerBase;
+
   void runOnOperation() override;
 };
 
@@ -1064,7 +1066,7 @@ void GreedySLPVectorizerPass::runOnOperation() {
 
     auto sizeInBits = dataLayout->getTypeSizeInBits(type);
 
-    return sizeInBits * count <= 256;
+    return sizeInBits * count <= this->maxVectorBitwidth;
   };
 
   // Run until fixed point is reached.
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 0a40037b015b1..262db81e16f21 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --greedy-slp-vectorizer | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(greedy-slp-vectorizer{max-vector-bitwidth=256}))' | FileCheck %s
 
 
 // CHECK-LABEL: func @negative_single_op

>From f542644fe567f051d18882d08c24dac718943584 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 20:54:35 +0200
Subject: [PATCH 42/52] fix test name

---
 mlir/test/Dialect/Vector/slp-vectorize.mlir | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 262db81e16f21..75b77561ed891 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -633,9 +633,9 @@ func.func @read_read_add_write_interleaved_use_add(%arg0: memref<8xi32>, %arg1:
 }
 
 
-// CHECK-LABEL: func @negative_different_blocks
+// CHECK-LABEL: func @different_blocks
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
-func.func @negative_different_blocks(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+func.func @different_blocks(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
     // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
     // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
     // CHECK:     %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>

>From c89d7c6ee7b7168a628e2153f318421001b7a224 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 21:04:44 +0200
Subject: [PATCH 43/52] cleanup

---
 mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 0e650359d3339..07bba1093d741 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -1031,7 +1031,7 @@ tryToVectorizeInBlock(Block &block,
                      << " operations\n";
       }
     });
-    rootGroups.append(contiguousGroups.begin(), contiguousGroups.end());
+    rootGroups.append(contiguousGroups);
   }
 
   // Build the SLP graph from root groups

>From d8fabe64323f95288b2debec47b411ce055cfeac Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 27 May 2025 00:14:34 +0200
Subject: [PATCH 44/52] AffineApplyOp index support

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 26 +++++++++++++++-
 mlir/test/Dialect/Vector/slp-vectorize.mlir   | 31 +++++++++++++++++++
 2 files changed, 56 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 07bba1093d741..892a8807d70e4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/DataLayoutAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -142,6 +143,27 @@ static Type getElementType(Operation *op) {
   return {};
 }
 
+/// Check if two `affine.apply` results are consecutive, i.e. idx1 + 1 == idx2.
+static bool isAdjacentAffineMapIndices(Value idx1, Value idx2) {
+  auto applyOp1 = idx1.getDefiningOp<affine::AffineApplyOp>();
+  auto applyOp2 = idx2.getDefiningOp<affine::AffineApplyOp>();
+  if (!applyOp1 || !applyOp2)
+    return false;
+
+  // The result exprs are only comparable if both maps bind the same operands
+  // as the same dims/symbols.
+  AffineMap map1 = applyOp1.getAffineMap();
+  AffineMap map2 = applyOp2.getAffineMap();
+  if (applyOp1.getOperands() != applyOp2.getOperands() ||
+      map1.getNumDims() != map2.getNumDims() ||
+      map1.getNumSymbols() != map2.getNumSymbols())
+    return false;
+
+  AffineExpr diff = simplifyAffineExpr(map2.getResult(0) - map1.getResult(0),
+                                       map1.getNumDims(), map1.getNumSymbols());
+  auto diffConst = dyn_cast<AffineConstantExpr>(diff);
+  return diffConst && diffConst.getValue() == 1;
+}
+
 /// Check if two indices are consecutive, i.e index1 + 1 == index2.
 static bool isAdjacentIndices(Value idx1, Value idx2) {
   if (auto c1 = getConstantIntValue(idx1)) {
@@ -160,7 +182,9 @@ static bool isAdjacentIndices(Value idx1, Value idx2) {
     }
   }
 
-  // TODO: Handle affine.apply, etc
+  if (isAdjacentAffineMapIndices(idx1, idx2))
+    return true;
+
   return false;
 }
 
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 75b77561ed891..4328926f8071f 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -143,6 +143,37 @@ func.func @read_write_add_index_interleaved(%arg0: memref<8xi32>, %arg1: memref<
 }
 
 
+#map0 = affine_map<()[s0, s1] -> (s1 * s0)>
+#map1 = affine_map<()[s0, s1] -> (s1 * s0 + 1)>
+#map2 = affine_map<()[s0, s1] -> (s1 * s0 + 2)>
+#map3 = affine_map<()[s0, s1] -> (s1 * s0 + 3)>
+
+// CHECK-LABEL: func @read_write_affine_apply
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+func.func @read_write_affine_apply(%arg0: memref<8xi32>, %arg1: memref<8xi32>, %arg2: index, %arg3: index) {
+  // CHECK:     %[[IDX:.*]] = affine.apply #{{.*}}()[%[[ARG2]], %[[ARG3]]]
+  // CHECK:     %[[RES:.*]] = vector.load %[[ARG0]][%[[IDX]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     vector.store %[[RES]], %[[ARG0]][%[[IDX]]] : memref<8xi32>, vector<4xi32>
+
+  %ind0 = affine.apply #map0()[%arg2, %arg3]
+  %ind1 = affine.apply #map1()[%arg2, %arg3]
+  %ind2 = affine.apply #map2()[%arg2, %arg3]
+  %ind3 = affine.apply #map3()[%arg2, %arg3]
+
+  %0 = memref.load %arg0[%ind0] : memref<8xi32>
+  %1 = memref.load %arg0[%ind1] : memref<8xi32>
+  %2 = memref.load %arg0[%ind2] : memref<8xi32>
+  %3 = memref.load %arg0[%ind3] : memref<8xi32>
+
+  memref.store %0, %arg0[%ind0] : memref<8xi32>
+  memref.store %1, %arg0[%ind1] : memref<8xi32>
+  memref.store %2, %arg0[%ind2] : memref<8xi32>
+  memref.store %3, %arg0[%ind3] : memref<8xi32>
+
+  return
+}
+
+
 // CHECK-LABEL: func @read_read_add
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
 func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i32, i32, i32) {

>From c6021c52d52d23affe377ab80350a3aef019fc76 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 27 May 2025 13:12:33 +0200
Subject: [PATCH 45/52] fix offset

---
 mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp | 8 +++++---
 mlir/test/Dialect/Vector/slp-vectorize.mlir          | 8 ++++----
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 892a8807d70e4..81aa63b31bdc3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -965,13 +965,13 @@ SLPGraph::vectorize(IRRewriter &rewriter,
       }
     };
 
-    auto handleVecSizeMismatch = [&](Value arg) -> Value {
+    auto handleVecSizeMismatch = [&](Value arg, int64_t offset = 0) -> Value {
       auto srcType = cast<VectorType>(arg.getType());
       assert(srcType.getRank() == 1);
       if (srcType.getDimSize(0) == numElements)
         return arg;
 
-      return rewriter.create<vector::ExtractStridedSliceOp>(loc, arg, 0,
+      return rewriter.create<vector::ExtractStridedSliceOp>(loc, arg, offset,
                                                             numElements, 1);
     };
 
@@ -1007,7 +1007,9 @@ SLPGraph::vectorize(IRRewriter &rewriter,
       mapping.map(op->getResults(), newOp->getResults());
       handleNonVectorOutputs(newOp->getResult(0));
     } else if (auto extract = dyn_cast<vector::ExtractOp>(op)) {
-      Value val = handleVecSizeMismatch(extract.getVector());
+      // The extract index was already verified during graph construction.
+      int64_t offset = *getExtractIndex(extract);
+      Value val = handleVecSizeMismatch(extract.getVector(), offset);
       mapping.map(extract.getResult(), val);
     } else {
       op->emitError("unsupported operation");
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 4328926f8071f..e339eb5755bd6 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -380,9 +380,9 @@ func.func @read_read_add_write_attrs_mismatch(%arg0: memref<8xi32>, %arg1: memre
     // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
     // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
     // CHECK:     %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
-    // CHECK:     %[[V1:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+    // CHECK:     %[[V1:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
     // CHECK:     %[[V2:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
-    // CHECK:     %[[V3:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+    // CHECK:     %[[V3:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
     // CHECK:     %[[V4:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
     // CHECK:     %[[V5:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
     // CHECK:     %[[V6:.*]] = arith.addi %[[V4]], %[[V5]] overflow<nsw> : vector<2xi32>
@@ -670,9 +670,9 @@ func.func @different_blocks(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
     // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
     // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
     // CHECK:     %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
-    // CHECK:     %[[V1:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+    // CHECK:     %[[V1:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
     // CHECK:     %[[V2:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
-    // CHECK:     %[[V3:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+    // CHECK:     %[[V3:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
     // CHECK:     cf.br ^bb1
     // CHECK:   ^bb1:
     // CHECK:     %[[V4:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>

>From 74e1d80def3a0191e6e4b340b2d67ede682f4618 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 31 May 2025 21:14:51 +0200
Subject: [PATCH 46/52] support for 1-element vectors

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 126 +++++++++++++-----
 mlir/test/Dialect/Vector/slp-vectorize.mlir   |  74 ++++++++++
 2 files changed, 168 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 81aa63b31bdc3..a9411c7c903bd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -68,6 +68,32 @@ static bool maybeWriteOp(Operation *op) {
   return effectInterface.hasEffect<MemoryEffects::Write>();
 }
 
+static Type getVectorElementType(VectorType vectorType) {
+  if (vectorType.getRank() > 1 || vectorType.isScalable() ||
+      vectorType.getNumElements() != 1)
+    return {};
+
+  return vectorType.getElementType();
+}
+
+static Type getElementType(Operation *op) {
+  assert(op && "null op");
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return loadOp.getResult().getType();
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return storeOp.getValueToStore().getType();
+  if (auto loadOp = dyn_cast<vector::LoadOp>(op))
+    return getVectorElementType(loadOp.getVectorType());
+  if (auto storeOp = dyn_cast<vector::StoreOp>(op))
+    return getVectorElementType(storeOp.getVectorType());
+  return {};
+}
+
+static bool isSupportedMemOp(Operation *op) {
+  assert(op && "null op");
+  return isa_and_present<IntegerType, FloatType, IndexType>(getElementType(op));
+}
+
 /// Collect all memory operations in the block into groups.
 /// Each group contains either all loads or all stores, uninterrupted by
 /// operations of the other type.
@@ -85,7 +111,7 @@ static SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block) {
       }
     }
 
-    if (!isa<memref::LoadOp, memref::StoreOp>(op))
+    if (!isSupportedMemOp(&op))
       continue;
 
     bool isLoad = maybeReadOp(&op);
@@ -109,6 +135,19 @@ static Value getBase(Operation *op) {
     return loadOp.getMemRef();
   if (auto storeOp = dyn_cast<memref::StoreOp>(op))
     return storeOp.getMemRef();
+  if (auto loadOp = dyn_cast<vector::LoadOp>(op))
+    return loadOp.getBase();
+  if (auto storeOp = dyn_cast<vector::StoreOp>(op))
+    return storeOp.getBase();
+  return {};
+}
+
+static Value getValueToStore(Operation *op) {
+  assert(op && "null op");
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return storeOp.getValueToStore();
+  if (auto storeOp = dyn_cast<vector::StoreOp>(op))
+    return storeOp.getValueToStore();
   return {};
 }
 
@@ -131,15 +170,10 @@ static ValueRange getIndices(Operation *op) {
     return loadOp.getIndices();
   if (auto storeOp = dyn_cast<memref::StoreOp>(op))
     return storeOp.getIndices();
-  return {};
-}
-
-static Type getElementType(Operation *op) {
-  assert(op && "null op");
-  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
-    return loadOp.getResult().getType();
-  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
-    return storeOp.getValueToStore().getType();
+  if (auto loadOp = dyn_cast<vector::LoadOp>(op))
+    return loadOp.getIndices();
+  if (auto storeOp = dyn_cast<vector::StoreOp>(op))
+    return storeOp.getIndices();
   return {};
 }
 
@@ -285,7 +319,15 @@ static bool isVectorizable(Operation *op) {
 
   for (auto type :
        llvm::concat<Type>(op->getResultTypes(), op->getOperandTypes())) {
-    if (!type.isIntOrIndexOrFloat())
+    if (auto vectorType = dyn_cast<VectorType>(type)) {
+      if (vectorType.getRank() > 1 || vectorType.isScalable() ||
+          vectorType.getNumElements() != 1)
+        return false;
+
+      type = vectorType.getElementType();
+    }
+
+    if (!isa<IntegerType, FloatType, IndexType>(type))
       return false;
   }
 
@@ -464,8 +506,7 @@ class SLPGraph {
     for (const auto &node : nodes) {
       if (!node->isRoot)
         continue;
-      llvm::dbgs() << "  "
-                   << (isa<memref::LoadOp>(node->op()) ? "LOAD" : "STORE")
+      llvm::dbgs() << "  " << (maybeReadOp(node->op()) ? "LOAD" : "STORE")
                    << " group with " << node->size() << " operations:\n";
       for (auto *op : node->ops) {
         llvm::dbgs() << "    " << *op << "\n";
@@ -657,20 +698,36 @@ checkOpVecType(SLPGraphNode *node,
                llvm::function_ref<bool(Type, size_t)> isValidVecType) {
   Operation *op = node->op();
   size_t size = node->size();
-  if (Type elementType = getElementType(op))
-    return isValidVecType(elementType, size);
+  auto checkRes = [](bool res) -> bool {
+    LLVM_DEBUG(llvm::dbgs() << (res ? "true" : "false") << "\n");
+    return res;
+  };
+
+  if (Type elementType = getElementType(op)) {
+    LLVM_DEBUG(llvm::dbgs() << "Checking if type " << elementType
+                            << " with size " << size << " can be vectorized: ");
+    return checkRes(isValidVecType(elementType, size));
+  }
 
   if (isVectorizable(op)) {
     for (auto type :
          llvm::concat<Type>(op->getResultTypes(), op->getOperandTypes())) {
-      if (!isValidVecType(type, size))
+      Type elementType = getElementTypeOrSelf(type);
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Checking if type " << elementType << " with size " << size
+                 << " can be vectorized: ");
+      if (!checkRes(isValidVecType(elementType, size)))
         return false;
     }
     return true;
   }
 
-  if (auto extract = dyn_cast<vector::ExtractOp>(op))
-    return isValidVecType(extract.getResult().getType(), size);
+  if (auto extract = dyn_cast<vector::ExtractOp>(op)) {
+    Type type = extract.getResult().getType();
+    LLVM_DEBUG(llvm::dbgs() << "Checking if type " << type << " with size "
+                            << size << " can be vectorized: ");
+    return checkRes(isValidVecType(type, size));
+  }
 
   LLVM_DEBUG(llvm::dbgs() << "Unsupported op " << op->getName() << "\n");
   return false;
@@ -903,12 +960,19 @@ SLPGraph::vectorize(IRRewriter &rewriter,
     for (auto *operand : node->operands)
       size = std::min(size, operand->size());
 
-    node->ops.resize(size);
+    if (size < node->size()) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Size mismatch, resizing node with " << node->size()
+                 << " operations to " << size << "\n");
+      node->ops.resize(size);
+    }
 
     while (node->size() > 1) {
       if (checkOpVecType(node, isValidVecType))
         break;
 
+      LLVM_DEBUG(llvm::dbgs() << "No a valid vector type, popping back op: "
+                              << node->ops.back()->getName() << "\n");
       node->ops.pop_back();
     }
   }
@@ -975,24 +1039,22 @@ SLPGraph::vectorize(IRRewriter &rewriter,
                                                             numElements, 1);
     };
 
-    if (auto load = dyn_cast<memref::LoadOp>(op)) {
-      auto vecType =
-          VectorType::get(numElements, load.getMemRefType().getElementType());
-      Value result = rewriter.create<vector::LoadOp>(
-          loc, vecType, load.getMemRef(), load.getIndices());
-      mapping.map(load.getResult(), result);
+    if (maybeReadOp(op)) {
+      auto vecType = VectorType::get(numElements, getElementType(op));
+      Value result = rewriter.create<vector::LoadOp>(loc, vecType, getBase(op),
+                                                     getIndices(op));
+      mapping.map(op->getResult(0), result);
       handleNonVectorOutputs(result);
-    } else if (auto store = dyn_cast<memref::StoreOp>(op)) {
-      handleNonVectorInputs(store.getValueToStore());
-      Value val = mapping.lookupOrDefault(store.getValueToStore());
+    } else if (maybeWriteOp(op)) {
+      handleNonVectorInputs(getValueToStore(op));
+      Value val = mapping.lookupOrDefault(getValueToStore(op));
       val = handleVecSizeMismatch(val);
-      rewriter.create<vector::StoreOp>(loc, val, store.getMemRef(),
-                                       store.getIndices());
+      rewriter.create<vector::StoreOp>(loc, val, getBase(op), getIndices(op));
     } else if (isVectorizable(op)) {
       handleNonVectorInputs(op->getOperands());
       Operation *newOp = rewriter.clone(*op, mapping);
-      auto resVectorType =
-          VectorType::get(numElements, op->getResultTypes().front());
+      Type resType = getElementTypeOrSelf(op->getResultTypes().front());
+      auto resVectorType = VectorType::get(numElements, resType);
 
       {
         OpBuilder::InsertionGuard guard(rewriter);
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index e339eb5755bd6..aeedececa1a7c 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -276,6 +276,80 @@ func.func @read_read_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
 }
 
 
+// CHECK-LABEL: func @read_read_add_write_vec_0d
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_vec_0d(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+  // CHECK:     vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %0 = vector.load %arg0[%c0] : memref<8xi32>, vector<i32>
+  %1 = vector.load %arg0[%c1] : memref<8xi32>, vector<i32>
+  %2 = vector.load %arg0[%c2] : memref<8xi32>, vector<i32>
+  %3 = vector.load %arg0[%c3] : memref<8xi32>, vector<i32>
+
+  %4 = vector.load %arg1[%c0] : memref<8xi32>, vector<i32>
+  %5 = vector.load %arg1[%c1] : memref<8xi32>, vector<i32>
+  %6 = vector.load %arg1[%c2] : memref<8xi32>, vector<i32>
+  %7 = vector.load %arg1[%c3] : memref<8xi32>, vector<i32>
+
+  %8 = arith.addi %0, %4 : vector<i32>
+  %9 = arith.addi %1, %5 : vector<i32>
+  %10 = arith.addi %2, %6 : vector<i32>
+  %11 = arith.addi %3, %7 : vector<i32>
+
+  vector.store %8, %arg0[%c0] : memref<8xi32>, vector<i32>
+  vector.store %9, %arg0[%c1] : memref<8xi32>, vector<i32>
+  vector.store %10, %arg0[%c2] : memref<8xi32>, vector<i32>
+  vector.store %11, %arg0[%c3] : memref<8xi32>, vector<i32>
+
+  return
+}
+
+
+// CHECK-LABEL: func @read_read_add_write_vec_1d
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_vec_1d(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+  // CHECK:     vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %0 = vector.load %arg0[%c0] : memref<8xi32>, vector<1xi32>
+  %1 = vector.load %arg0[%c1] : memref<8xi32>, vector<1xi32>
+  %2 = vector.load %arg0[%c2] : memref<8xi32>, vector<1xi32>
+  %3 = vector.load %arg0[%c3] : memref<8xi32>, vector<1xi32>
+
+  %4 = vector.load %arg1[%c0] : memref<8xi32>, vector<1xi32>
+  %5 = vector.load %arg1[%c1] : memref<8xi32>, vector<1xi32>
+  %6 = vector.load %arg1[%c2] : memref<8xi32>, vector<1xi32>
+  %7 = vector.load %arg1[%c3] : memref<8xi32>, vector<1xi32>
+
+  %8 = arith.addi %0, %4 : vector<1xi32>
+  %9 = arith.addi %1, %5 : vector<1xi32>
+  %10 = arith.addi %2, %6 : vector<1xi32>
+  %11 = arith.addi %3, %7 : vector<1xi32>
+
+  vector.store %8, %arg0[%c0] : memref<8xi32>, vector<1xi32>
+  vector.store %9, %arg0[%c1] : memref<8xi32>, vector<1xi32>
+  vector.store %10, %arg0[%c2] : memref<8xi32>, vector<1xi32>
+  vector.store %11, %arg0[%c3] : memref<8xi32>, vector<1xi32>
+
+  return
+}
+
+
 // CHECK-LABEL: func @read_read_add_write_seven
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xindex>, %[[ARG1:.*]]: memref<8xindex>)
 func.func @read_read_add_write_seven(%arg0: memref<8xindex>, %arg1: memref<8xindex>) {

>From 5e473ab19148652d7688a43a51447e3880a239e9 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 31 May 2025 21:23:44 +0200
Subject: [PATCH 47/52] refac size() -> opsCount()

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 45 ++++++++++---------
 1 file changed, 24 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index a9411c7c903bd..e2b3156bd8f6c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -49,7 +49,7 @@ struct MemoryOpGroup {
   bool isLoadGroup() const { return type == Type::Load; }
   bool isStoreGroup() const { return type == Type::Store; }
 
-  size_t size() const { return ops.size(); }
+  size_t opsCount() const { return ops.size(); }
 };
 
 static bool maybeReadOp(Operation *op) {
@@ -305,7 +305,7 @@ extractContiguousGroups(const MemoryOpGroup &group) {
     }
 
     LLVM_DEBUG(llvm::dbgs() << "Extracted contiguous group with "
-                            << currentGroup.size() << " operations\n");
+                            << currentGroup.opsCount() << " operations\n");
   }
   return result;
 }
@@ -353,7 +353,7 @@ struct SLPGraphNode {
   SLPGraphNode(ArrayRef<Operation *> operations)
       : ops(operations.begin(), operations.end()) {}
 
-  size_t size() const { return ops.size(); }
+  size_t opsCount() const { return ops.size(); }
 
   Operation *op() const {
     assert(!ops.empty() && "empty ops");
@@ -507,13 +507,14 @@ class SLPGraph {
       if (!node->isRoot)
         continue;
       llvm::dbgs() << "  " << (maybeReadOp(node->op()) ? "LOAD" : "STORE")
-                   << " group with " << node->size() << " operations:\n";
+                   << " group with " << node->opsCount() << " operations:\n";
       for (auto *op : node->ops) {
         llvm::dbgs() << "    " << *op << "\n";
       }
       llvm::dbgs() << "    Users: ";
       for (auto *user : node->users) {
-        llvm::dbgs() << "\n      Group with " << user->size() << " operations:";
+        llvm::dbgs() << "\n      Group with " << user->opsCount()
+                     << " operations:";
         for (auto *op : user->ops) {
           llvm::dbgs() << "\n        " << *op;
         }
@@ -526,13 +527,13 @@ class SLPGraph {
     for (const auto &node : nodes) {
       if (node->isRoot)
         continue;
-      llvm::dbgs() << "  Group with " << node->size() << " operations:\n";
+      llvm::dbgs() << "  Group with " << node->opsCount() << " operations:\n";
       for (auto *op : node->ops) {
         llvm::dbgs() << "    " << *op << "\n";
       }
       llvm::dbgs() << "    Operands: ";
       for (auto *operand : node->operands) {
-        llvm::dbgs() << "\n      Group with " << operand->size()
+        llvm::dbgs() << "\n      Group with " << operand->opsCount()
                      << " operations:";
         for (auto *op : operand->ops) {
           llvm::dbgs() << "\n        " << *op;
@@ -540,7 +541,8 @@ class SLPGraph {
       }
       llvm::dbgs() << "\n    Users: ";
       for (auto *user : node->users) {
-        llvm::dbgs() << "\n      Group with " << user->size() << " operations:";
+        llvm::dbgs() << "\n      Group with " << user->opsCount()
+                     << " operations:";
         for (auto *op : user->ops) {
           llvm::dbgs() << "\n        " << *op;
         }
@@ -697,7 +699,7 @@ static bool
 checkOpVecType(SLPGraphNode *node,
                llvm::function_ref<bool(Type, size_t)> isValidVecType) {
   Operation *op = node->op();
-  size_t size = node->size();
+  size_t size = node->opsCount();
   auto checkRes = [](bool res) -> bool {
     LLVM_DEBUG(llvm::dbgs() << (res ? "true" : "false") << "\n");
     return res;
@@ -779,7 +781,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
     worklist.push_back(node);
 
     LLVM_DEBUG({
-      llvm::dbgs() << "Created root group node with " << node->size()
+      llvm::dbgs() << "Created root group node with " << node->opsCount()
                    << " operations of type "
                    << (group.isLoadGroup() ? "Load" : "Store") << "\n";
     });
@@ -907,7 +909,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
   while (!worklist.empty()) {
     SLPGraphNode *node = worklist.pop_back_val();
     LLVM_DEBUG(llvm::dbgs()
-               << "Processing node with " << node->size()
+               << "Processing node with " << node->opsCount()
                << " operations, first op: " << node->op()->getName() << "\n");
 
     Operation *op = node->op();
@@ -940,7 +942,7 @@ SLPGraph::vectorize(IRRewriter &rewriter,
   LLVM_DEBUG({
     llvm::dbgs() << "Topologically sorted nodes:\n";
     for (auto *node : sortedNodes) {
-      llvm::dbgs() << "  Node with " << node->size()
+      llvm::dbgs() << "  Node with " << node->opsCount()
                    << " operations: " << node->op()->getName() << "\n";
     }
   });
@@ -948,7 +950,8 @@ SLPGraph::vectorize(IRRewriter &rewriter,
   auto isBadNode = [&](SLPGraphNode *node) {
     // Do not vectorize stray nodes which are not connected to any other
     // nodes.
-    return (node->users.empty() && node->operands.empty()) || node->size() <= 1;
+    return (node->users.empty() && node->operands.empty()) ||
+           node->opsCount() <= 1;
   };
 
   // Update node vec sizes if its inputs vec sizes are smaller.
@@ -956,18 +959,18 @@ SLPGraph::vectorize(IRRewriter &rewriter,
   // TODO: It maybe possible to reconstruct the larger vec size combining src
   // smaller vector and scalar arg.
   for (auto *node : sortedNodes) {
-    size_t size = node->size();
+    size_t size = node->opsCount();
     for (auto *operand : node->operands)
-      size = std::min(size, operand->size());
+      size = std::min(size, operand->opsCount());
 
-    if (size < node->size()) {
+    if (size < node->opsCount()) {
       LLVM_DEBUG(llvm::dbgs()
-                 << "Size mismatch, resizing node with " << node->size()
+                 << "Size mismatch, resizing node with " << node->opsCount()
                  << " operations to " << size << "\n");
       node->ops.resize(size);
     }
 
-    while (node->size() > 1) {
+    while (node->opsCount() > 1) {
       if (checkOpVecType(node, isValidVecType))
         break;
 
@@ -982,7 +985,7 @@ SLPGraph::vectorize(IRRewriter &rewriter,
   IRMapping mapping;
   for (auto *node : sortedNodes) {
     LLVM_DEBUG({
-      llvm::dbgs() << "Processing node with " << node->size()
+      llvm::dbgs() << "Processing node with " << node->opsCount()
                    << " operations\n";
       llvm::dbgs() << "  First op: " << *node->op() << "\n";
     });
@@ -997,7 +1000,7 @@ SLPGraph::vectorize(IRRewriter &rewriter,
     LLVM_DEBUG(llvm::dbgs() << "  Insertion point: " << *ip << "\n");
 
     rewriter.setInsertionPoint(ip);
-    int64_t numElements = node->size();
+    int64_t numElements = node->opsCount();
     Location loc = op->getLoc();
 
     auto handleNonVectorInputs = [&](ValueRange operands) {
@@ -1115,7 +1118,7 @@ tryToVectorizeInBlock(Block &block,
                    << " contiguous groups in "
                    << (group.isLoadGroup() ? "load" : "store") << " group\n";
       for (const auto &contigGroup : contiguousGroups) {
-        llvm::dbgs() << "  Contiguous group with " << contigGroup.size()
+        llvm::dbgs() << "  Contiguous group with " << contigGroup.opsCount()
                      << " operations\n";
       }
     });

>From 4161f5a63af8e950ba8150eafeb4439ecf5d8bb3 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 31 May 2025 22:37:50 +0200
Subject: [PATCH 48/52] merge vectorized ops too

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 129 ++++++++++++------
 mlir/test/Dialect/Vector/slp-vectorize.mlir   |  34 ++---
 2 files changed, 105 insertions(+), 58 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index e2b3156bd8f6c..c010193226814 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -43,6 +43,7 @@ struct MemoryOpGroup {
   enum class Type { Load, Store };
   Type type;
   SmallVector<Operation *> ops;
+  int64_t elementsCount = 0;
 
   MemoryOpGroup(Type t) : type(t) {}
 
@@ -68,30 +69,37 @@ static bool maybeWriteOp(Operation *op) {
   return effectInterface.hasEffect<MemoryEffects::Write>();
 }
 
-static Type getVectorElementType(VectorType vectorType) {
-  if (vectorType.getRank() > 1 || vectorType.isScalable() ||
-      vectorType.getNumElements() != 1)
-    return {};
+static std::optional<std::pair<Type, int64_t>>
+getVectorElementTypeAndCount(VectorType vectorType) {
+  if (vectorType.getRank() > 1 || vectorType.isScalable())
+    return std::nullopt;
 
-  return vectorType.getElementType();
+  return std::make_pair(vectorType.getElementType(),
+                        vectorType.getNumElements());
 }
 
-static Type getElementType(Operation *op) {
+static std::optional<std::pair<Type, int64_t>>
+getElementTypeAndCount(Operation *op) {
   assert(op && "null op");
   if (auto loadOp = dyn_cast<memref::LoadOp>(op))
-    return loadOp.getResult().getType();
+    return std::make_pair(loadOp.getResult().getType(), 1);
   if (auto storeOp = dyn_cast<memref::StoreOp>(op))
-    return storeOp.getValueToStore().getType();
+    return std::make_pair(storeOp.getValueToStore().getType(), 1);
   if (auto loadOp = dyn_cast<vector::LoadOp>(op))
-    return getVectorElementType(loadOp.getVectorType());
+    return getVectorElementTypeAndCount(loadOp.getVectorType());
   if (auto storeOp = dyn_cast<vector::StoreOp>(op))
-    return getVectorElementType(storeOp.getVectorType());
-  return {};
+    return getVectorElementTypeAndCount(storeOp.getVectorType());
+  return std::nullopt;
 }
 
 static bool isSupportedMemOp(Operation *op) {
   assert(op && "null op");
-  return isa_and_present<IntegerType, FloatType, IndexType>(getElementType(op));
+  auto typeAndCount = getElementTypeAndCount(op);
+  if (!typeAndCount)
+    return false;
+
+  return isa_and_present<IntegerType, FloatType, IndexType>(
+      typeAndCount->first);
 }
 
 /// Collect all memory operations in the block into groups.
@@ -177,7 +185,7 @@ static ValueRange getIndices(Operation *op) {
   return {};
 }
 
-static bool isAdjacentAffineMapIndices(Value idx1, Value idx2) {
+static bool isAdjacentAffineMapIndices(Value idx1, Value idx2, int64_t offset) {
   auto applyOp1 = idx1.getDefiningOp<affine::AffineApplyOp>();
   if (!applyOp1)
     return false;
@@ -195,28 +203,29 @@ static bool isAdjacentAffineMapIndices(Value idx1, Value idx2) {
       simplifyAffineExpr(expr2 - expr1, 0, applyOp1.getOperands().size());
 
   auto diffConst = dyn_cast<AffineConstantExpr>(diff);
-  return diffConst && diffConst.getValue() == 1;
+  return diffConst && diffConst.getValue() == offset;
 }
 
 /// Check if two indices are consecutive, i.e index1 + 1 == index2.
-static bool isAdjacentIndices(Value idx1, Value idx2) {
+static bool isAdjacentIndices(Value idx1, Value idx2, int64_t offset) {
   if (auto c1 = getConstantIntValue(idx1)) {
     if (auto c2 = getConstantIntValue(idx2))
-      return *c1 + 1 == *c2;
+      return *c1 + offset == *c2;
   }
 
   if (auto addOp2 = idx2.getDefiningOp<arith::AddIOp>()) {
-    if (addOp2.getLhs() == idx1 && getConstantIntValue(addOp2.getRhs()) == 1)
+    if (addOp2.getLhs() == idx1 &&
+        getConstantIntValue(addOp2.getRhs()) == offset)
       return true;
 
     if (auto addOp1 = idx1.getDefiningOp<arith::AddIOp>()) {
       if (addOp1.getLhs() == addOp2.getLhs() &&
-          isAdjacentIndices(addOp1.getRhs(), addOp2.getRhs()))
+          isAdjacentIndices(addOp1.getRhs(), addOp2.getRhs(), offset))
         return true;
     }
   }
 
-  if (isAdjacentAffineMapIndices(idx1, idx2))
+  if (isAdjacentAffineMapIndices(idx1, idx2, offset))
     return true;
 
   return false;
@@ -224,19 +233,22 @@ static bool isAdjacentIndices(Value idx1, Value idx2) {
 
 /// Check if two ranges of indices are consecutive, i.e fastest index differs
 /// by 1 and all other indices are the same.
-static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2) {
+static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2,
+                              int64_t offset) {
   if (idx1.empty() || idx1.size() != idx2.size())
     return false;
 
   if (idx1.drop_back() != idx2.drop_back())
     return false;
 
-  return isAdjacentIndices(idx1.back(), idx2.back());
+  return isAdjacentIndices(idx1.back(), idx2.back(), offset);
 }
 
 /// Check if two operations are adjacent and can be combined into a vector op.
 /// This is done by checking if the base memrefs are the same, the last
-/// dimension is contiguous, and the element types and indices are compatible
+/// dimension is contiguous, and the element types and indices are compatible.
+/// If source read/write is already vectorized, only merge ops if vector
+/// elements count is the same.
 static bool isAdjacentOps(Operation *op1, Operation *op2) {
   assert(op1 && "null op1");
   assert(op2 && "null op2");
@@ -249,10 +261,19 @@ static bool isAdjacentOps(Operation *op1, Operation *op2) {
   if (!isContiguousLastDim(base1))
     return false;
 
-  if (getElementType(op1) != getElementType(op2))
+  auto typeAndCount1 = getElementTypeAndCount(op1);
+  if (!typeAndCount1)
+    return false;
+
+  auto typeAndCount2 = getElementTypeAndCount(op2);
+  if (!typeAndCount2)
     return false;
 
-  return isAdjacentIndices(getIndices(op1), getIndices(op2));
+  if (typeAndCount1 != typeAndCount2)
+    return false;
+
+  return isAdjacentIndices(getIndices(op1), getIndices(op2),
+                           typeAndCount1->second);
 }
 
 // Extract contiguous groups from a MemoryOpGroup
@@ -271,6 +292,7 @@ extractContiguousGroups(const MemoryOpGroup &group) {
     // Start a new group with this operation
     result.emplace_back(group.type);
     MemoryOpGroup &currentGroup = result.back();
+    currentGroup.elementsCount = getElementTypeAndCount(op)->second;
     auto &currentOps = currentGroup.ops;
     currentOps.push_back(op);
     processedOps.insert(op);
@@ -310,7 +332,9 @@ extractContiguousGroups(const MemoryOpGroup &group) {
   return result;
 }
 
-static bool isVectorizable(Operation *op) {
+static bool
+isVectorizable(Operation *op,
+               std::optional<int64_t> expectedElementsCount = std::nullopt) {
   if (!OpTrait::hasElementwiseMappableTraits(op))
     return false;
 
@@ -319,14 +343,18 @@ static bool isVectorizable(Operation *op) {
 
   for (auto type :
        llvm::concat<Type>(op->getResultTypes(), op->getOperandTypes())) {
+    int64_t vectorElementsCount = 1;
     if (auto vectorType = dyn_cast<VectorType>(type)) {
-      if (vectorType.getRank() > 1 || vectorType.isScalable() ||
-          vectorType.getNumElements() != 1)
+      if (vectorType.getRank() > 1 || vectorType.isScalable())
         return false;
 
       type = vectorType.getElementType();
+      vectorElementsCount = vectorType.getNumElements();
     }
 
+    if (expectedElementsCount && vectorElementsCount != *expectedElementsCount)
+      return false;
+
     if (!isa<IntegerType, FloatType, IndexType>(type))
       return false;
   }
@@ -347,6 +375,7 @@ struct SLPGraphNode {
   SmallVector<SLPGraphNode *> users;
   SmallVector<SLPGraphNode *> operands;
   Operation *insertionPoint = nullptr;
+  int64_t elementsCount = 0;
   bool isRoot = false;
 
   SLPGraphNode() = default;
@@ -354,6 +383,7 @@ struct SLPGraphNode {
       : ops(operations.begin(), operations.end()) {}
 
   size_t opsCount() const { return ops.size(); }
+  size_t vectorSize() const { return elementsCount * opsCount(); }
 
   Operation *op() const {
     assert(!ops.empty() && "empty ops");
@@ -415,17 +445,20 @@ class SLPGraph {
   SLPGraph &operator=(SLPGraph &&) = default;
 
   /// Add a new node to the graph
-  SLPGraphNode *addNode(ArrayRef<Operation *> operations) {
+  SLPGraphNode *addNode(ArrayRef<Operation *> operations,
+                        int64_t elementsCount) {
     nodes.push_back(std::make_unique<SLPGraphNode>(operations));
     auto *node = nodes.back().get();
+    node->elementsCount = elementsCount;
     for (Operation *op : operations)
       opToNode[op] = node;
     return node;
   }
 
   /// Add a root node (memory operation)
-  SLPGraphNode *addRoot(ArrayRef<Operation *> operations) {
-    auto *node = addNode(operations);
+  SLPGraphNode *addRoot(ArrayRef<Operation *> operations,
+                        int64_t elementsCount) {
+    auto *node = addNode(operations, elementsCount);
     node->isRoot = true;
     return node;
   }
@@ -699,13 +732,14 @@ static bool
 checkOpVecType(SLPGraphNode *node,
                llvm::function_ref<bool(Type, size_t)> isValidVecType) {
   Operation *op = node->op();
-  size_t size = node->opsCount();
+  size_t size = node->vectorSize();
   auto checkRes = [](bool res) -> bool {
     LLVM_DEBUG(llvm::dbgs() << (res ? "true" : "false") << "\n");
     return res;
   };
 
-  if (Type elementType = getElementType(op)) {
+  if (auto typeAndCount = getElementTypeAndCount(op)) {
+    Type elementType = typeAndCount->first;
     LLVM_DEBUG(llvm::dbgs() << "Checking if type " << elementType
                             << " with size " << size << " can be vectorized: ");
     return checkRes(isValidVecType(elementType, size));
@@ -777,7 +811,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
 
   // First, create nodes for each contiguous memory operation group
   for (const auto &group : rootGroups) {
-    auto *node = graph.addRoot(group.ops);
+    auto *node = graph.addRoot(group.ops, group.elementsCount);
     worklist.push_back(node);
 
     LLVM_DEBUG({
@@ -800,7 +834,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
       return;
     }
 
-    if (!isVectorizable(user))
+    if (!isVectorizable(user, node->elementsCount))
       return;
 
     Fingerprint expectedFingerprint = fingerprints.getFingerprint(user);
@@ -830,7 +864,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
     if (currentOps.size() == 1)
       return;
 
-    auto *newNode = graph.addNode(currentOps);
+    auto *newNode = graph.addNode(currentOps, node->elementsCount);
     graph.addEdge(node, newNode);
     for (Operation *op : currentOps)
       fingerprints.invalidate(op);
@@ -877,7 +911,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
         currentOps.push_back(otherOp);
         ++currentIndex;
       }
-    } else if (isVectorizable(srcOp)) {
+    } else if (isVectorizable(srcOp, node->elementsCount)) {
       LLVM_DEBUG(llvm::dbgs() << "  Processing vectorizable op "
                               << srcOp->getName() << "\n");
 
@@ -898,7 +932,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
     if (currentOps.size() == 1)
       return;
 
-    auto *newNode = graph.addNode(currentOps);
+    auto *newNode = graph.addNode(currentOps, node->elementsCount);
     graph.addEdge(newNode, node);
     for (Operation *op : currentOps)
       fingerprints.invalidate(op);
@@ -1000,7 +1034,7 @@ SLPGraph::vectorize(IRRewriter &rewriter,
     LLVM_DEBUG(llvm::dbgs() << "  Insertion point: " << *ip << "\n");
 
     rewriter.setInsertionPoint(ip);
-    int64_t numElements = node->opsCount();
+    int64_t numElements = node->vectorSize();
     Location loc = op->getLoc();
 
     auto handleNonVectorInputs = [&](ValueRange operands) {
@@ -1009,10 +1043,20 @@ SLPGraph::vectorize(IRRewriter &rewriter,
           continue;
 
         SmallVector<Value> args;
-        for (Operation *defOp : node->ops)
-          args.push_back(defOp->getOperand(i));
+        for (Operation *defOp : node->ops) {
+          Value arg = defOp->getOperand(i);
+          if (auto vecType = dyn_cast<VectorType>(arg.getType())) {
+            assert(vecType.getRank() == 1);
+            for (auto j : llvm::seq(vecType.getNumElements()))
+              args.push_back(rewriter.create<vector::ExtractOp>(loc, arg, j));
+
+          } else {
+            args.push_back(arg);
+          }
+        }
 
-        auto vecType = VectorType::get(numElements, operand.getType());
+        auto vecType = VectorType::get(numElements,
+                                       getElementTypeOrSelf(operand.getType()));
         Value vector =
             rewriter.create<vector::FromElementsOp>(loc, vecType, args);
         mapping.map(operand, vector);
@@ -1043,7 +1087,8 @@ SLPGraph::vectorize(IRRewriter &rewriter,
     };
 
     if (maybeReadOp(op)) {
-      auto vecType = VectorType::get(numElements, getElementType(op));
+      auto vecType =
+          VectorType::get(numElements, getElementTypeAndCount(op)->first);
       Value result = rewriter.create<vector::LoadOp>(loc, vecType, getBase(op),
                                                      getIndices(op));
       mapping.map(op->getResult(0), result);
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index aeedececa1a7c..598ba5c755ab1 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -646,22 +646,24 @@ func.func private @use(i32)
 // CHECK-LABEL: func @read_read_add_write_interleaved_use
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
 func.func @read_read_add_write_interleaved_use(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
-  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
-  // CHECK:     %[[C2:.*]] = arith.constant 2 : index
-  // CHECK:     %[[C3:.*]] = arith.constant 3 : index
-  // CHECK:     %[[V0:.*]] = memref.load %arg0[%[[C3]]] : memref<8xi32>
-  // CHECK:     %[[V1:.*]] = memref.load %arg1[%[[C3]]] : memref<8xi32>
-  // CHECK:     call @use(%[[V0]]) : (i32) -> ()
-  // CHECK:     %[[V2:.*]] = vector.load %arg0[%[[C0]]] : memref<8xi32>, vector<2xi32>
-  // CHECK:     %[[V3:.*]] = vector.load %arg1[%[[C0]]] : memref<8xi32>, vector<2xi32>
-  // CHECK:     %[[V4:.*]] = memref.load %arg0[%[[C2]]] : memref<8xi32>
-  // CHECK:     %[[V5:.*]] = memref.load %arg1[%[[C2]]] : memref<8xi32>
-  // CHECK:     %[[V6:.*]] = vector.from_elements %[[V4]], %[[V0]] : vector<2xi32>
-  // CHECK:     %[[V7:.*]] = vector.from_elements %[[V5]], %[[V1]] : vector<2xi32>
-  // CHECK:     %[[V8:.*]] = arith.addi %[[V6]], %[[V7]] : vector<2xi32>
-  // CHECK:     %[[V9:.*]] = arith.addi %[[V2]], %[[V3]] : vector<2xi32>
-  // CHECK:     vector.store %[[V9]], %arg0[%[[C0]]] : memref<8xi32>, vector<2xi32>
-  // CHECK:     vector.store %[[V8]], %arg0[%[[C2]]] : memref<8xi32>, vector<2xi32>
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+  // CHECK: %[[V0:.*]] = memref.load %[[ARG0]][%[[C3]]] : memref<8xi32>
+  // CHECK: %[[V1:.*]] = memref.load %[[ARG1]][%[[C3]]] : memref<8xi32>
+  // CHECK: call @use(%[[V0]]) : (i32) -> ()
+  // CHECK: %[[V2:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+  // CHECK: %[[V3:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+  // CHECK: %[[V4:.*]] = memref.load %[[ARG0]][%[[C2]]] : memref<8xi32>
+  // CHECK: %[[V5:.*]] = memref.load %[[ARG1]][%[[C2]]] : memref<8xi32>
+  // CHECK: %[[V6:.*]] = vector.extract %[[V2]][0] : i32 from vector<2xi32>
+  // CHECK: %[[V7:.*]] = vector.extract %[[V2]][1] : i32 from vector<2xi32>
+  // CHECK: %[[V8:.*]] = vector.from_elements %[[V6]], %[[V7]], %[[V4]], %[[V0]] : vector<4xi32>
+  // CHECK: %[[V9:.*]] = vector.extract %[[V3]][0] : i32 from vector<2xi32>
+  // CHECK: %[[V10:.*]] = vector.extract %[[V3]][1] : i32 from vector<2xi32>
+  // CHECK: %[[V11:.*]] = vector.from_elements %[[V9]], %[[V10]], %[[V5]], %[[V1]] : vector<4xi32>
+  // CHECK: %[[V12:.*]] = arith.addi %[[V8]], %[[V11]] : vector<4xi32>
+  // CHECK: vector.store %[[V12]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
 
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index

>From 0e557a7e5965d2c7b11f65ffd88852828b8ca29e Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 1 Jun 2025 12:05:59 +0200
Subject: [PATCH 49/52] more tests

---
 mlir/test/Dialect/Vector/slp-vectorize.mlir | 32 +++++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 598ba5c755ab1..a6108287551f4 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -350,6 +350,38 @@ func.func @read_read_add_write_vec_1d(%arg0: memref<8xi32>, %arg1: memref<8xi32>
 }
 
 
+// CHECK-LABEL: func @read_read_add_write_mixed_vecs
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_mixed_vecs(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+  // CHECK:     vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %0 = vector.load %arg0[%c0] : memref<8xi32>, vector<2xi32>
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %3 = vector.load %arg0[%c3] : memref<8xi32>, vector<1xi32>
+
+  %4 = vector.load %arg1[%c0] : memref<8xi32>, vector<2xi32>
+  %6 = memref.load %arg1[%c2] : memref<8xi32>
+  %7 = vector.load %arg1[%c3] : memref<8xi32>, vector<1xi32>
+
+  %8 = arith.addi %0, %4 : vector<2xi32>
+  %10 = arith.addi %2, %6 : i32
+  %11 = arith.addi %3, %7 : vector<1xi32>
+
+  vector.store %8, %arg0[%c0] : memref<8xi32>, vector<2xi32>
+  memref.store %10, %arg0[%c2] : memref<8xi32>
+  vector.store %11, %arg0[%c3] : memref<8xi32>, vector<1xi32>
+
+  return
+}
+
+
 // CHECK-LABEL: func @read_read_add_write_seven
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xindex>, %[[ARG1:.*]]: memref<8xindex>)
 func.func @read_read_add_write_seven(%arg0: memref<8xindex>, %arg1: memref<8xindex>) {

>From 8e2d6118a7d82c1106088ac24f5d745bff0958cc Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 1 Jun 2025 12:35:38 +0200
Subject: [PATCH 50/52] comments

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 33 +++++++++++++++----
 1 file changed, 26 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index c010193226814..3af4b9dfb4ce4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -89,6 +89,7 @@ getElementTypeAndCount(Operation *op) {
     return getVectorElementTypeAndCount(loadOp.getVectorType());
   if (auto storeOp = dyn_cast<vector::StoreOp>(op))
     return getVectorElementTypeAndCount(storeOp.getVectorType());
+
   return std::nullopt;
 }
 
@@ -147,7 +148,8 @@ static Value getBase(Operation *op) {
     return loadOp.getBase();
   if (auto storeOp = dyn_cast<vector::StoreOp>(op))
     return storeOp.getBase();
-  return {};
+
+  llvm_unreachable("unsupported op");
 }
 
 static Value getValueToStore(Operation *op) {
@@ -156,7 +158,8 @@ static Value getValueToStore(Operation *op) {
     return storeOp.getValueToStore();
   if (auto storeOp = dyn_cast<vector::StoreOp>(op))
     return storeOp.getValueToStore();
-  return {};
+
+  llvm_unreachable("unsupported op");
 }
 
 static bool isContiguousLastDim(Value val) {
@@ -182,7 +185,8 @@ static ValueRange getIndices(Operation *op) {
     return loadOp.getIndices();
   if (auto storeOp = dyn_cast<vector::StoreOp>(op))
     return storeOp.getIndices();
-  return {};
+
+  llvm_unreachable("unsupported op");
 }
 
 static bool isAdjacentAffineMapIndices(Value idx1, Value idx2, int64_t offset) {
@@ -206,7 +210,7 @@ static bool isAdjacentAffineMapIndices(Value idx1, Value idx2, int64_t offset) {
   return diffConst && diffConst.getValue() == offset;
 }
 
-/// Check if two indices are consecutive, i.e index1 + 1 == index2.
+/// Check if two indices are consecutive, i.e. index1 + offset == index2.
 static bool isAdjacentIndices(Value idx1, Value idx2, int64_t offset) {
   if (auto c1 = getConstantIntValue(idx1)) {
     if (auto c2 = getConstantIntValue(idx2))
@@ -232,7 +236,7 @@ static bool isAdjacentIndices(Value idx1, Value idx2, int64_t offset) {
 }
 
 /// Check if two ranges of indices are consecutive, i.e fastest index differs
-/// by 1 and all other indices are the same.
+/// by `offset` and all other indices are the same.
 static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2,
                               int64_t offset) {
   if (idx1.empty() || idx1.size() != idx2.size())
@@ -272,6 +276,7 @@ static bool isAdjacentOps(Operation *op1, Operation *op2) {
   if (typeAndCount1 != typeAndCount2)
     return false;
 
+  // For now we only merge ops with the same element count.
   return isAdjacentIndices(getIndices(op1), getIndices(op2),
                            typeAndCount1->second);
 }
@@ -332,6 +337,9 @@ extractContiguousGroups(const MemoryOpGroup &group) {
   return result;
 }
 
+/// Check if an operation is vectorizable.
+/// If `expectedElementsCount` is provided, check if original op had the
+/// specified number of elements.
 static bool
 isVectorizable(Operation *op,
                std::optional<int64_t> expectedElementsCount = std::nullopt) {
@@ -362,7 +370,8 @@ isVectorizable(Operation *op,
   return true;
 }
 
-/// Get the next operation in the block, assuming `op` is not a terminator.
+/// Get the next operation in the block, assuming `op` is neither the terminator
+/// nor the last operation in the block.
 static Operation *nextOp(Operation *op) {
   assert(op && "null op");
   auto it = op->getIterator();
@@ -390,6 +399,9 @@ struct SLPGraphNode {
     return ops.front();
   }
 
+  /// Get a suitable insertion point for the new vectorized op.
+  /// This method also takes the insertion points of the operands into account
+  /// so that the new op satisfies dominance relations.
   Operation *getInsertionPoint() {
     assert(!ops.empty() && "empty node");
     if (insertionPoint)
@@ -1038,6 +1050,9 @@ SLPGraph::vectorize(IRRewriter &rewriter,
     Location loc = op->getLoc();
 
     auto handleNonVectorInputs = [&](ValueRange operands) {
+      // Handle operands that are not vectorized or have a smaller vector
+      // size: construct the vector from the scalar operands (or their
+      // extracted elements) using FromElementsOp.
       for (auto [i, operand] : llvm::enumerate(operands)) {
         if (getNodeForOp(operand.getDefiningOp()))
           continue;
@@ -1064,6 +1079,8 @@ SLPGraph::vectorize(IRRewriter &rewriter,
     };
 
     auto handleNonVectorOutputs = [&](Value newResult) {
+      // Handle the case when op results are not vectorized or have smaller
+      // vector size, extract the elements from the vector.
       for (auto [i, result] : llvm::enumerate(node->ops)) {
         for (OpOperand &use : result->getUses()) {
           Operation *useOwner = use.getOwner();
@@ -1077,6 +1094,7 @@ SLPGraph::vectorize(IRRewriter &rewriter,
     };
 
     auto handleVecSizeMismatch = [&](Value arg, int64_t offset = 0) -> Value {
+      // Handle vector size mismatch between 2 vectorized nodes.
       auto srcType = cast<VectorType>(arg.getType());
       assert(srcType.getRank() == 1);
       if (srcType.getDimSize(0) == numElements)
@@ -1117,7 +1135,8 @@ SLPGraph::vectorize(IRRewriter &rewriter,
       mapping.map(op->getResults(), newOp->getResults());
       handleNonVectorOutputs(newOp->getResult(0));
     } else if (auto extract = dyn_cast<vector::ExtractOp>(op)) {
-      // We alredy verified index is valid during graph construction.
+      // We alredy verified index is valid during graph construction, so
+      // do need to check `getExtractIndex` result.
       int64_t offset = *getExtractIndex(extract);
       Value val = handleVecSizeMismatch(extract.getVector(), offset);
       mapping.map(extract.getResult(), val);

>From 04cc9219d07e5f82be86f228d22f5ba4925b9f6d Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 1 Jun 2025 12:48:06 +0200
Subject: [PATCH 51/52] vector outputs handling

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 21 +++++++++++----
 mlir/test/Dialect/Vector/slp-vectorize.mlir   | 26 +++++++++++++++++++
 2 files changed, 42 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 3af4b9dfb4ce4..9a389612567df 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -1078,7 +1078,8 @@ SLPGraph::vectorize(IRRewriter &rewriter,
       }
     };
 
-    auto handleNonVectorOutputs = [&](Value newResult) {
+    auto handleNonVectorOutputs = [&](Value newResult,
+                                      Type originalResultType) {
       // Handle the case when op results are not vectorized or have smaller
       // vector size, extract the elements from the vector.
       for (auto [i, result] : llvm::enumerate(node->ops)) {
@@ -1087,7 +1088,16 @@ SLPGraph::vectorize(IRRewriter &rewriter,
           if (getNodeForOp(useOwner))
             continue;
 
-          Value elem = rewriter.create<vector::ExtractOp>(loc, newResult, i);
+          int64_t offset = i * node->elementsCount;
+          Value elem;
+
+          if (auto vecType = dyn_cast<VectorType>(originalResultType)) {
+            elem = rewriter.create<vector::ExtractStridedSliceOp>(
+                loc, newResult, offset, vecType.getNumElements(), 1);
+          } else {
+            elem = rewriter.create<vector::ExtractOp>(loc, newResult, offset);
+          }
+
           use.set(elem);
         }
       }
@@ -1109,8 +1119,9 @@ SLPGraph::vectorize(IRRewriter &rewriter,
           VectorType::get(numElements, getElementTypeAndCount(op)->first);
       Value result = rewriter.create<vector::LoadOp>(loc, vecType, getBase(op),
                                                      getIndices(op));
-      mapping.map(op->getResult(0), result);
-      handleNonVectorOutputs(result);
+      Value originalResult = op->getResult(0);
+      mapping.map(originalResult, result);
+      handleNonVectorOutputs(result, originalResult.getType());
     } else if (maybeWriteOp(op)) {
       handleNonVectorInputs(getValueToStore(op));
       Value val = mapping.lookupOrDefault(getValueToStore(op));
@@ -1133,7 +1144,7 @@ SLPGraph::vectorize(IRRewriter &rewriter,
       newOp->getResult(0).setType(resVectorType);
 
       mapping.map(op->getResults(), newOp->getResults());
-      handleNonVectorOutputs(newOp->getResult(0));
+      handleNonVectorOutputs(newOp->getResult(0), op->getResultTypes().front());
     } else if (auto extract = dyn_cast<vector::ExtractOp>(op)) {
       // We alredy verified index is valid during graph construction, so
       // do need to check `getExtractIndex` result.
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index a6108287551f4..38490ba4934a4 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -672,6 +672,32 @@ func.func @read_read_add_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>,
 }
 
 
+// CHECK-LABEL: func @read_read_add_add_vec
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_add_vec(%arg0: memref<8xi32>, %arg1: memref<8xi32>) ->
+                                   (vector<2xi32>, vector<2xi32>){
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK: %[[V1:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK: %[[V2:.*]] = arith.addi %[[V0]], %[[V1]] : vector<4xi32>
+  // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+  // CHECK: %[[V4:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+  // CHECK: return %[[V3]], %[[V4]] : vector<2xi32>, vector<2xi32>
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+
+  %0 = vector.load %arg0[%c0] : memref<8xi32>, vector<2xi32>
+  %2 = vector.load %arg0[%c2] : memref<8xi32>, vector<2xi32>
+
+  %4 = vector.load %arg1[%c0] : memref<8xi32>, vector<2xi32>
+  %6 = vector.load %arg1[%c2] : memref<8xi32>, vector<2xi32>
+
+  %8 = arith.addi %0, %4 : vector<2xi32>
+  %10 = arith.addi %2, %6 : vector<2xi32>
+
+  return %8, %10 : vector<2xi32>, vector<2xi32>
+}
+
 
 func.func private @use(i32)
 

>From 8578695b9a84880319e50b0bfc5c591bca094cc5 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 1 Jun 2025 12:58:42 +0200
Subject: [PATCH 52/52] vector handling

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 10 +++-
 mlir/test/Dialect/Vector/slp-vectorize.mlir   | 56 +++++++++++++++++++
 2 files changed, 64 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 9a389612567df..58c4c5b271458 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -1092,8 +1092,14 @@ SLPGraph::vectorize(IRRewriter &rewriter,
           Value elem;
 
           if (auto vecType = dyn_cast<VectorType>(originalResultType)) {
-            elem = rewriter.create<vector::ExtractStridedSliceOp>(
-                loc, newResult, offset, vecType.getNumElements(), 1);
+            assert(vecType.getRank() <= 1);
+            if (vecType.getRank() == 0) {
+              elem = rewriter.create<vector::ExtractOp>(loc, newResult, offset);
+              elem = rewriter.create<vector::SplatOp>(loc, vecType, elem);
+            } else {
+              elem = rewriter.create<vector::ExtractStridedSliceOp>(
+                  loc, newResult, offset, vecType.getNumElements(), 1);
+            }
           } else {
             elem = rewriter.create<vector::ExtractOp>(loc, newResult, offset);
           }
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 38490ba4934a4..29c077d7ab34f 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -699,6 +699,62 @@ func.func @read_read_add_add_vec(%arg0: memref<8xi32>, %arg1: memref<8xi32>) ->
 }
 
 
+// CHECK-LABEL: func @read_read_add_add_vec1
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_add_vec1(%arg0: memref<8xi32>, %arg1: memref<8xi32>) ->
+                                   (vector<1xi32>, vector<1xi32>){
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+  // CHECK: %[[V1:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+  // CHECK: %[[V2:.*]] = arith.addi %[[V0]], %[[V1]] : vector<2xi32>
+  // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xi32> to vector<1xi32>
+  // CHECK: %[[V4:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [1], sizes = [1], strides = [1]} : vector<2xi32> to vector<1xi32>
+  // CHECK: return %[[V3]], %[[V4]] : vector<1xi32>, vector<1xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+
+  %0 = vector.load %arg0[%c0] : memref<8xi32>, vector<1xi32>
+  %2 = vector.load %arg0[%c1] : memref<8xi32>, vector<1xi32>
+
+  %4 = vector.load %arg1[%c0] : memref<8xi32>, vector<1xi32>
+  %6 = vector.load %arg1[%c1] : memref<8xi32>, vector<1xi32>
+
+  %8 = arith.addi %0, %4 : vector<1xi32>
+  %10 = arith.addi %2, %6 : vector<1xi32>
+
+  return %8, %10 : vector<1xi32>, vector<1xi32>
+}
+
+
+// CHECK-LABEL: func @read_read_add_add_vec0d
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_add_vec0d(%arg0: memref<8xi32>, %arg1: memref<8xi32>) ->
+                                   (vector<i32>, vector<i32>){
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+  // CHECK: %[[V1:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+  // CHECK: %[[V2:.*]] = arith.addi %[[V0]], %[[V1]] : vector<2xi32>
+  // CHECK: %[[V3:.*]] = vector.extract %[[V2]][0] : i32 from vector<2xi32>
+  // CHECK: %[[V4:.*]] = vector.splat %[[V3]] : vector<i32>
+  // CHECK: %[[V5:.*]] = vector.extract %[[V2]][1] : i32 from vector<2xi32>
+  // CHECK: %[[V6:.*]] = vector.splat %[[V5]] : vector<i32>
+  // CHECK: return %[[V4]], %[[V6]] : vector<i32>, vector<i32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+
+  %0 = vector.load %arg0[%c0] : memref<8xi32>, vector<i32>
+  %2 = vector.load %arg0[%c1] : memref<8xi32>, vector<i32>
+
+  %4 = vector.load %arg1[%c0] : memref<8xi32>, vector<i32>
+  %6 = vector.load %arg1[%c1] : memref<8xi32>, vector<i32>
+
+  %8 = arith.addi %0, %4 : vector<i32>
+  %10 = arith.addi %2, %6 : vector<i32>
+
+  return %8, %10 : vector<i32>, vector<i32>
+}
+
+
 func.func private @use(i32)
 
 // CHECK-LABEL: func @read_read_add_write_interleaved_use



More information about the Mlir-commits mailing list