[Mlir-commits] [mlir] [mlir][vector] MLIR SLP vectorizer (PR #140469)

Ivan Butygin llvmlistbot at llvm.org
Mon May 19 02:00:43 PDT 2025


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/140469

>From 36c7f1e004c3a872dff9005c7113b6c99760671c Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 17 May 2025 23:01:41 +0200
Subject: [PATCH 01/28] stubs

---
 .../mlir/Dialect/Vector/Transforms/Passes.h   |  3 +
 .../mlir/Dialect/Vector/Transforms/Passes.td  | 12 ++++
 .../Dialect/Vector/Transforms/CMakeLists.txt  |  1 +
 .../Vector/Transforms/SLPVectorizer.cpp       | 63 +++++++++++++++++++
 mlir/test/Dialect/Vector/slp-vectorize.mlir   | 34 ++++++++++
 5 files changed, 113 insertions(+)
 create mode 100644 mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
 create mode 100644 mlir/test/Dialect/Vector/slp-vectorize.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index 5667f4fa95ace..43112f084dc60 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -25,6 +25,9 @@ std::unique_ptr<Pass> createLowerVectorMultiReductionPass(
     VectorMultiReductionLowering option =
         VectorMultiReductionLowering::InnerParallel);
 
+/// Creates a pass that implements the SLP (Superword Level Parallelism)
+/// vectorizer, registered in Passes.td as `slp-vectorizer`.
+std::unique_ptr<Pass> createSLPVectorizerPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 7436998749791..94ccd61cb5170 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -34,4 +34,16 @@ def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::Func
   ];
 }
 
+// Declarative definition of the SLP vectorizer pass. It is registered as
+// `slp-vectorizer`, anchored on ModuleOp, and instantiated through
+// mlir::vector::createSLPVectorizerPass(). It lists the vector dialect as a
+// dependent dialect.
+def SLPVectorizer : Pass<"slp-vectorizer", "ModuleOp"> {
+  let summary = "SLP Vectorizer Pass";
+  let description = [{
+    This pass implements the SLP (Superword Level Parallelism) vectorizer.
+    It detects consecutive operations that can be put together into vector
+    operations. The pass works bottom-up, across basic blocks, in search of
+    scalars to combine.
+  }];
+  let constructor = "mlir::vector::createSLPVectorizerPass()";
+  let dependentDialects = ["mlir::vector::VectorDialect"];
+}
+
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8ca5cb6c6dfab..37333b739bd86 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorStep.cpp
   LowerVectorTransfer.cpp
   LowerVectorTranspose.cpp
+  SLPVectorizer.cpp
   SubsetOpInterfaceImpl.cpp
   VectorDistribute.cpp
   VectorDropLeadUnitDim.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
new file mode 100644
index 0000000000000..e9f3b12bc7461
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -0,0 +1,63 @@
+//===- SLPVectorizer.cpp - SLP Vectorizer Pass ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the SLP vectorizer pass for MLIR. The pass attempts to
+// combine similar independent operations into vector operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "slp-vectorizer"
+
+namespace mlir {
+namespace vector {
+#define GEN_PASS_DEF_SLPVECTORIZER
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+} // namespace vector
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+/// This pass implements the SLP vectorizer. It detects consecutive operations
+/// that can be put together into vector operations. The pass works bottom-up,
+/// across basic blocks, in search of scalars to combine.
+///
+/// NOTE(review): in this revision the pass is still a stub (see the TODO in
+/// runOnOperation). The paragraph above describes the intended design.
+struct SLPVectorizerPass
+    : public mlir::vector::impl::SLPVectorizerBase<SLPVectorizerPass> {
+  /// Pass entry point, run on the anchored ModuleOp.
+  void runOnOperation() override;
+};
+
+} // namespace
+
+/// Stub entry point: no transformation is performed yet.
+void SLPVectorizerPass::runOnOperation() {
+  // TODO: Implement SLP vectorization logic
+  // 1. Find candidate operations for vectorization
+  // 2. Build vectorization trees
+  // 3. Perform vectorization if profitable
+  // 4. Clean up scalar operations
+
+  // Trace only under -debug-only=slp-vectorizer (DEBUG_TYPE). A compiler pass
+  // must not write to stderr unconditionally on every invocation.
+  LLVM_DEBUG(llvm::dbgs() << "Running SLP Vectorizer pass\n");
+}
+
+/// Factory named by the `constructor` field of the SLPVectorizer pass
+/// definition in Passes.td.
+std::unique_ptr<Pass> mlir::vector::createSLPVectorizerPass() {
+  return std::make_unique<SLPVectorizerPass>();
+}
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
new file mode 100644
index 0000000000000..31543f3a76b2e
--- /dev/null
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s --slp-vectorizer | FileCheck %s
+
+// Four independent load/load/addi/store chains over indices 0..3 of %arg0 and
+// %arg1. The directives below expect vector transfers and a vector addi in
+// the output.
+
+// CHECK-LABEL: func @basic_slp
+func.func @basic_slp(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK: vector.transfer_read
+  // CHECK: arith.addi
+  // CHECK: vector.transfer_write
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+  %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %5 = memref.load %arg1[%c1] : memref<8xi32>
+  %6 = memref.load %arg1[%c2] : memref<8xi32>
+  %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+  %8 = arith.addi %0, %4 : i32
+  %9 = arith.addi %1, %5 : i32
+  %10 = arith.addi %2, %6 : i32
+  %11 = arith.addi %3, %7 : i32
+
+  memref.store %8, %arg0[%c0] : memref<8xi32>
+  memref.store %9, %arg0[%c1] : memref<8xi32>
+  memref.store %10, %arg0[%c2] : memref<8xi32>
+  memref.store %11, %arg0[%c3] : memref<8xi32>
+
+  return
+}

>From 2b4e64b0e7790b4b265c5924e1c0764ca1083a40 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 17 May 2025 23:20:34 +0200
Subject: [PATCH 02/28] something working

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 113 +++++++++++++++++-
 mlir/test/Dialect/Vector/slp-vectorize.mlir   |   2 +-
 2 files changed, 108 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index e9f3b12bc7461..b696f36c82eee 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/Passes.h"
@@ -34,27 +35,127 @@ using namespace mlir;
 using namespace mlir::vector;
 
 namespace {
+/// A group of memory operations of the same kind (all loads or all stores)
+/// that can potentially be vectorized together.
+struct MemoryOpGroup {
+  /// Kind of memory access held by the group.
+  enum class Type { Load, Store };
+  Type type;
+  /// Member operations (memref.load or memref.store, matching `type`).
+  SmallVector<Operation *> ops;
+
+  // `explicit` prevents accidental implicit conversion from `Type`. The
+  // emplace_back(type) call sites use direct-initialization, so they still work.
+  explicit MemoryOpGroup(Type t) : type(t) {}
+
+  bool isLoadGroup() const { return type == Type::Load; }
+  bool isStoreGroup() const { return type == Type::Store; }
+
+  size_t size() const { return ops.size(); }
+  bool empty() const { return ops.empty(); }
+};
+
 /// This pass implements the SLP vectorizer. It detects consecutive operations
 /// that can be put together into vector operations. The pass works bottom-up,
 /// across basic blocks, in search of scalars to combine.
 struct SLPVectorizerPass
     : public mlir::vector::impl::SLPVectorizerBase<SLPVectorizerPass> {
   void runOnOperation() override;
+
+private:
+  /// Collect all memory operations in the block into groups.
+  /// Each group contains either all loads or all stores, uninterrupted by
+  /// operations of the other type.
+  SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block);
 };
 
 } // namespace
 
+/// Partition the memref.load / memref.store operations of `block` into
+/// maximal runs of the same kind. A new group starts whenever the access kind
+/// switches between load and store. Operations of any other kind are skipped
+/// and do not end the current group.
+///
+/// NOTE(review): side-effecting operations other than memref.load/store do
+/// not split a group. Confirm this is acceptable before any transformation
+/// reorders accesses within a group.
+SmallVector<MemoryOpGroup>
+SLPVectorizerPass::collectMemoryOpGroups(Block &block) {
+  SmallVector<MemoryOpGroup> groups;
+
+  LLVM_DEBUG(llvm::dbgs() << "Scanning block for memory operations...\n");
+
+  for (Operation &op : block) {
+    LLVM_DEBUG(llvm::dbgs() << "Checking operation: " << op.getName() << "\n");
+
+    // Skip non-memory operations
+    if (!isa<memref::LoadOp, memref::StoreOp>(op)) {
+      LLVM_DEBUG(llvm::dbgs() << "  Not a memory operation\n");
+      continue;
+    }
+
+    bool isLoad = isa<memref::LoadOp>(op);
+    MemoryOpGroup::Type type =
+        isLoad ? MemoryOpGroup::Type::Load : MemoryOpGroup::Type::Store;
+
+    LLVM_DEBUG(llvm::dbgs()
+               << "  Found " << (isLoad ? "load" : "store") << " operation\n");
+
+    // Start a new group if there is none yet or the access kind changed.
+    // Each group is created here and immediately receives `op`, so no group
+    // is ever empty. This makes a separate "remove empty groups" pass
+    // unnecessary.
+    if (groups.empty() || groups.back().type != type) {
+      LLVM_DEBUG(llvm::dbgs() << "  Starting new group\n");
+      groups.emplace_back(type);
+    }
+
+    groups.back().ops.push_back(&op);
+  }
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "Found " << groups.size() << " memory operation groups:\n";
+    for (const auto &group : groups) {
+      llvm::dbgs() << "  Group type: "
+                   << (group.isLoadGroup() ? "Load" : "Store")
+                   << ", size: " << group.size() << "\n";
+    }
+  });
+
+  return groups;
+}
+
 void SLPVectorizerPass::runOnOperation() {
   Operation *op = getOperation();
   MLIRContext *context = &getContext();
 
-  // TODO: Implement SLP vectorization logic
-  // 1. Find candidate operations for vectorization
-  // 2. Build vectorization trees
-  // 3. Perform vectorization if profitable
-  // 4. Clean up scalar operations
+  LLVM_DEBUG(llvm::dbgs() << "Running SLP Vectorizer pass on operation: "
+                          << op->getName() << "\n");
+
+  // Process each function in the module
+  for (Region &region : op->getRegions()) {
+    for (Block &block : region) {
+      for (Operation &op : block) {
+        // If this is a function, process its body
+        if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
+          LLVM_DEBUG(llvm::dbgs()
+                     << "Processing function: " << funcOp.getName() << "\n");
+
+          // Process each block in the function
+          for (Block &funcBlock : funcOp.getBody()) {
+            // Collect memory operation groups
+            SmallVector<MemoryOpGroup> groups =
+                collectMemoryOpGroups(funcBlock);
+
+            LLVM_DEBUG({
+              llvm::dbgs() << "Found " << groups.size()
+                           << " memory operation groups:\n";
+              for (const auto &group : groups) {
+                llvm::dbgs() << "  Group type: "
+                             << (group.isLoadGroup() ? "Load" : "Store")
+                             << ", size: " << group.size() << "\n";
+              }
+            });
+          }
+        }
+      }
+    }
+  }
 
-  LLVM_DEBUG(llvm::dbgs() << "Running SLP Vectorizer pass\n");
   llvm::errs() << "Running SLP Vectorizer pass\n";
 }
 
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 31543f3a76b2e..a07dd05dd16aa 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-slp-vectorization | FileCheck %s
+// RUN: mlir-opt %s --slp-vectorizer | FileCheck %s
 
 // CHECK-LABEL: func @basic_slp
 func.func @basic_slp(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {

>From 14804ce4fa80df248e435998a40dfc953e463324 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 17 May 2025 23:29:20 +0200
Subject: [PATCH 03/28] block walk

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 64 ++++---------------
 1 file changed, 11 insertions(+), 53 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index b696f36c82eee..bec5f9d90b21b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -72,29 +72,19 @@ SLPVectorizerPass::collectMemoryOpGroups(Block &block) {
   SmallVector<MemoryOpGroup> groups;
   MemoryOpGroup *currentGroup = nullptr;
 
-  LLVM_DEBUG(llvm::dbgs() << "Scanning block for memory operations...\n");
-
   for (Operation &op : block) {
-    LLVM_DEBUG(llvm::dbgs() << "Checking operation: " << op.getName() << "\n");
-
     // Skip non-memory operations
-    if (!isa<memref::LoadOp, memref::StoreOp>(op)) {
-      LLVM_DEBUG(llvm::dbgs() << "  Not a memory operation\n");
+    if (!isa<memref::LoadOp, memref::StoreOp>(op))
       continue;
-    }
 
     bool isLoad = isa<memref::LoadOp>(op);
     MemoryOpGroup::Type type =
         isLoad ? MemoryOpGroup::Type::Load : MemoryOpGroup::Type::Store;
 
-    LLVM_DEBUG(llvm::dbgs()
-               << "  Found " << (isLoad ? "load" : "store") << " operation\n");
-
     // Start a new group if:
     // 1. We don't have a current group, or
     // 2. The current operation is a different type than the current group
     if (!currentGroup || currentGroup->type != type) {
-      LLVM_DEBUG(llvm::dbgs() << "  Starting new group\n");
       groups.emplace_back(type);
       currentGroup = &groups.back();
     }
@@ -107,15 +97,6 @@ SLPVectorizerPass::collectMemoryOpGroups(Block &block) {
                               [](const MemoryOpGroup &g) { return g.empty(); }),
                groups.end());
 
-  LLVM_DEBUG({
-    llvm::dbgs() << "Found " << groups.size() << " memory operation groups:\n";
-    for (const auto &group : groups) {
-      llvm::dbgs() << "  Group type: "
-                   << (group.isLoadGroup() ? "Load" : "Store")
-                   << ", size: " << group.size() << "\n";
-    }
-  });
-
   return groups;
 }
 
@@ -123,40 +104,17 @@ void SLPVectorizerPass::runOnOperation() {
   Operation *op = getOperation();
   MLIRContext *context = &getContext();
 
-  LLVM_DEBUG(llvm::dbgs() << "Running SLP Vectorizer pass on operation: "
-                          << op->getName() << "\n");
-
-  // Process each function in the module
-  for (Region &region : op->getRegions()) {
-    for (Block &block : region) {
-      for (Operation &op : block) {
-        // If this is a function, process its body
-        if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
-          LLVM_DEBUG(llvm::dbgs()
-                     << "Processing function: " << funcOp.getName() << "\n");
-
-          // Process each block in the function
-          for (Block &funcBlock : funcOp.getBody()) {
-            // Collect memory operation groups
-            SmallVector<MemoryOpGroup> groups =
-                collectMemoryOpGroups(funcBlock);
-
-            LLVM_DEBUG({
-              llvm::dbgs() << "Found " << groups.size()
-                           << " memory operation groups:\n";
-              for (const auto &group : groups) {
-                llvm::dbgs() << "  Group type: "
-                             << (group.isLoadGroup() ? "Load" : "Store")
-                             << ", size: " << group.size() << "\n";
-              }
-            });
-          }
-        }
-      }
-    }
-  }
+  // Walk all blocks recursively
+  op->walk([&](Block *block) {
+    LLVM_DEBUG(llvm::dbgs() << "Processing block in operation: "
+                            << block->getParentOp()->getName() << "\n");
 
-  llvm::errs() << "Running SLP Vectorizer pass\n";
+    // Collect memory operation groups
+    SmallVector<MemoryOpGroup> groups = collectMemoryOpGroups(*block);
+
+    LLVM_DEBUG(llvm::dbgs() << "Found " << groups.size()
+                            << " memory operation groups in block\n");
+  });
 }
 
 std::unique_ptr<Pass> mlir::vector::createSLPVectorizerPass() {

>From b2c5c5cdd73aac4374bc6015159ec08f95b06082 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 00:10:20 +0200
Subject: [PATCH 04/28] contiguous groups

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 106 +++++++++++++++++-
 1 file changed, 104 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index bec5f9d90b21b..f46dc71537ef3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -51,6 +51,96 @@ struct MemoryOpGroup {
   bool empty() const { return ops.empty(); }
 };
 
+/// Split `group` into runs of accesses to consecutive constant indices of the
+/// same rank-1 memref.
+///
+/// Each run starts at the first not-yet-used operation, in `group.ops` order.
+/// It is then extended with accesses at index start+1, start+2, and so on.
+/// An access is not placed in any run if its index is non-constant or its
+/// memref does not have exactly one index. Every returned group is non-empty.
+///
+/// NOTE(review): a run only grows upwards from its first operation. For
+/// example, accesses to indices 1 then 0 (in that order) form two
+/// single-element runs.
+SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
+  SmallVector<MemoryOpGroup> result;
+  if (group.ops.empty())
+    return result;
+
+  // Returns (memref, constant index) for an access with exactly one index.
+  // Requiring a single index avoids calling front() on the empty index list
+  // of a 0-d memref. It also avoids treating A[0][x] and A[1][y] as adjacent
+  // just because their first indices differ by one.
+  auto fromAccess = [](Value base, ValueRange indices)
+      -> std::optional<std::pair<Value, int64_t>> {
+    if (indices.size() != 1)
+      return std::nullopt;
+    if (std::optional<int64_t> index = getConstantIntValue(indices.front()))
+      return std::make_pair(base, *index);
+    return std::nullopt;
+  };
+  auto getAccess =
+      [&](Operation *op) -> std::optional<std::pair<Value, int64_t>> {
+    if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+      return fromAccess(loadOp.getMemRef(), loadOp.getIndices());
+    if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+      return fromAccess(storeOp.getMemRef(), storeOp.getIndices());
+    return std::nullopt;
+  };
+
+  // Operations already placed into some run.
+  DenseSet<Operation *> processedOps;
+
+  for (Operation *op : group.ops) {
+    if (processedOps.contains(op))
+      continue;
+
+    std::optional<std::pair<Value, int64_t>> access = getAccess(op);
+    if (!access)
+      continue;
+    auto [base, index] = *access;
+
+    // Start a new run with this operation.
+    result.emplace_back(group.type);
+    MemoryOpGroup &currentGroup = result.back();
+    currentGroup.ops.push_back(op);
+    processedOps.insert(op);
+
+    LLVM_DEBUG(llvm::dbgs() << "Starting new group at base " << base
+                            << " index " << index << "\n");
+
+    // Repeatedly look for an access to the same memref at the index directly
+    // after the run's current end. The expected index is relative to the
+    // run's start, not to 0.
+    bool foundMore;
+    do {
+      foundMore = false;
+      const int64_t nextIndex =
+          index + static_cast<int64_t>(currentGroup.ops.size());
+      for (Operation *otherOp : group.ops) {
+        if (processedOps.contains(otherOp))
+          continue;
+
+        std::optional<std::pair<Value, int64_t>> otherAccess =
+            getAccess(otherOp);
+        if (!otherAccess || otherAccess->first != base ||
+            otherAccess->second != nextIndex)
+          continue;
+
+        currentGroup.ops.push_back(otherOp);
+        processedOps.insert(otherOp);
+        foundMore = true;
+        LLVM_DEBUG(llvm::dbgs()
+                   << "Added operation with index " << nextIndex << "\n");
+        break;
+      }
+    } while (foundMore);
+  }
+
+  return result;
+}
+
 /// This pass implements the SLP vectorizer. It detects consecutive operations
 /// that can be put together into vector operations. The pass works bottom-up,
 /// across basic blocks, in search of scalars to combine.
@@ -112,8 +202,20 @@ void SLPVectorizerPass::runOnOperation() {
     // Collect memory operation groups
     SmallVector<MemoryOpGroup> groups = collectMemoryOpGroups(*block);
 
-    LLVM_DEBUG(llvm::dbgs() << "Found " << groups.size()
-                            << " memory operation groups in block\n");
+    // Process each group to find contiguous sequences
+    for (const auto &group : groups) {
+      SmallVector<MemoryOpGroup> contiguousGroups =
+          extractContiguousGroups(group);
+      LLVM_DEBUG({
+        llvm::dbgs() << "Found " << contiguousGroups.size()
+                     << " contiguous groups in "
+                     << (group.isLoadGroup() ? "load" : "store") << " group\n";
+        for (const auto &contigGroup : contiguousGroups) {
+          llvm::dbgs() << "  Contiguous group with " << contigGroup.size()
+                       << " operations\n";
+        }
+      });
+    }
   });
 }
 

>From e01213d8ec4284b24732050f14caca64d18a0123 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 00:14:53 +0200
Subject: [PATCH 05/28] refac

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 55 ++++++++-----------
 1 file changed, 22 insertions(+), 33 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index f46dc71537ef3..9a0ba5264bc40 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -51,6 +51,18 @@ struct MemoryOpGroup {
   bool empty() const { return ops.empty(); }
 };
 
+// Helper function to extract base and index from a memory operation.
+//
+// Returns the accessed memref and its constant index for a memref.load or
+// memref.store with exactly one index. Returns std::nullopt in all other
+// cases:
+//   - a non-constant index;
+//   - a 0-d memref, which has no index, so front() would be invalid;
+//   - a multi-dimensional access, where the first index alone does not
+//     determine adjacency.
+std::optional<std::pair<Value, int64_t>> getBaseAndIndex(Operation *op) {
+  auto fromAccess = [](Value base, ValueRange indices)
+      -> std::optional<std::pair<Value, int64_t>> {
+    if (indices.size() != 1)
+      return std::nullopt;
+    if (std::optional<int64_t> index = getConstantIntValue(indices.front()))
+      return std::make_pair(base, *index);
+    return std::nullopt;
+  };
+
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return fromAccess(loadOp.getMemRef(), loadOp.getIndices());
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return fromAccess(storeOp.getMemRef(), storeOp.getIndices());
+  return std::nullopt;
+}
+
 // Extract contiguous groups from a MemoryOpGroup
 SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
   SmallVector<MemoryOpGroup> result;
@@ -67,24 +79,12 @@ SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
       continue;
 
     // Get base and index of current operation
-    Value base;
-    int64_t index = -1;
-    if (group.isLoadGroup()) {
-      auto loadOp = cast<memref::LoadOp>(op);
-      if (auto value = getConstantIntValue(loadOp.getIndices().front())) {
-        index = *value;
-        base = loadOp.getMemRef();
-      }
-    } else {
-      auto storeOp = cast<memref::StoreOp>(op);
-      if (auto value = getConstantIntValue(storeOp.getIndices().front())) {
-        index = *value;
-        base = storeOp.getMemRef();
-      }
-    }
-    if (index == -1)
+    auto baseAndIndex = getBaseAndIndex(op);
+    if (!baseAndIndex)
       continue;
 
+    auto [base, index] = *baseAndIndex;
+
     // Start a new group with this operation
     result.emplace_back(group.type);
     MemoryOpGroup &currentGroup = result.back();
@@ -103,25 +103,14 @@ SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
         if (processedOps.contains(otherOp))
           continue;
 
-        Value otherBase;
-        int64_t otherIndex = -1;
-        if (group.isLoadGroup()) {
-          auto loadOp = cast<memref::LoadOp>(otherOp);
-          if (auto value = getConstantIntValue(loadOp.getIndices().front())) {
-            otherIndex = *value;
-            otherBase = loadOp.getMemRef();
-          }
-        } else {
-          auto storeOp = cast<memref::StoreOp>(otherOp);
-          if (auto value = getConstantIntValue(storeOp.getIndices().front())) {
-            otherIndex = *value;
-            otherBase = storeOp.getMemRef();
-          }
-        }
+        auto otherBaseAndIndex = getBaseAndIndex(otherOp);
+        if (!otherBaseAndIndex)
+          continue;
+
+        auto [otherBase, otherIndex] = *otherBaseAndIndex;
 
         // Check if this operation has the same base and adjacent index
-        if (otherIndex != -1 && otherBase == base &&
-            otherIndex == currentGroup.ops.size()) {
+        if (otherBase == base && otherIndex == currentGroup.ops.size()) {
           currentGroup.ops.push_back(otherOp);
           processedOps.insert(otherOp);
           foundMore = true;

>From 99158188e89ab1d35947370e8f5be7369d875b97 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 00:35:42 +0200
Subject: [PATCH 06/28] SLPGraph

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 162 ++++++++++++++++++
 1 file changed, 162 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 9a0ba5264bc40..4355dc33648c0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -130,6 +130,160 @@ SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
   return result;
 }
 
+/// A node in the SLP graph representing a vectorizable operation
+struct SLPGraphNode {
+  Operation *op;
+  /// Nodes recorded as users of this node (filled by SLPGraph::addEdge).
+  DenseSet<SLPGraphNode *> users;
+  /// Nodes recorded as operands of this node (filled by SLPGraph::addEdge).
+  DenseSet<SLPGraphNode *> operands;
+  /// Set for nodes created through SLPGraph::addRoot.
+  bool isRoot = false;
+
+  // `explicit`: an Operation * should not silently convert into a node.
+  explicit SLPGraphNode(Operation *op) : op(op) {}
+};
+
+/// A graph of vectorizable operations.
+///
+/// The graph owns its nodes through std::unique_ptr. Each node is
+/// heap-allocated individually, so the raw pointers returned by
+/// addNode/addRoot stay valid while the graph is alive, even as `nodes` grows.
+/// With unique_ptr ownership (Rule of Zero), a graph can be moved, for example
+/// when returned by value. Copying it fails to compile instead of producing
+/// two owners that would both delete the same nodes.
+class SLPGraph {
+public:
+  SLPGraph() = default;
+
+  /// Add a new node to the graph
+  SLPGraphNode *addNode(Operation *op) {
+    nodes.push_back(std::make_unique<SLPGraphNode>(op));
+    return nodes.back().get();
+  }
+
+  /// Add a root node (memory operation)
+  SLPGraphNode *addRoot(Operation *op) {
+    SLPGraphNode *node = addNode(op);
+    node->isRoot = true;
+    return node;
+  }
+
+  /// Add a dependency edge: record `to` as a user of `from` and `from` as an
+  /// operand of `to`. Duplicate edges are ignored by the underlying sets.
+  void addEdge(SLPGraphNode *from, SLPGraphNode *to) {
+    from->users.insert(to);
+    to->operands.insert(from);
+  }
+
+  /// Get all root nodes, in insertion order.
+  SmallVector<SLPGraphNode *> getRoots() const {
+    SmallVector<SLPGraphNode *> roots;
+    for (const auto &node : nodes)
+      if (node->isRoot)
+        roots.push_back(node.get());
+    return roots;
+  }
+
+  /// Print the graph structure (roots first, then non-root nodes) to
+  /// llvm::dbgs(). Users and operands are printed in DenseSet iteration order.
+  void print() const {
+    llvm::dbgs() << "SLP Graph Structure:\n";
+    llvm::dbgs() << "===================\n";
+
+    // First print all roots
+    llvm::dbgs() << "Roots:\n";
+    for (const auto &node : nodes) {
+      if (!node->isRoot)
+        continue;
+      llvm::dbgs() << "  " << *node->op << "\n";
+      llvm::dbgs() << "    Users: ";
+      for (auto *user : node->users) {
+        llvm::dbgs() << "\n      " << *user->op;
+      }
+      llvm::dbgs() << "\n";
+    }
+
+    // Then print all non-root nodes
+    llvm::dbgs() << "\nNon-root nodes:\n";
+    for (const auto &node : nodes) {
+      if (node->isRoot)
+        continue;
+      llvm::dbgs() << "  " << *node->op << "\n";
+      llvm::dbgs() << "    Operands: ";
+      for (auto *operand : node->operands) {
+        llvm::dbgs() << "\n      " << *operand->op;
+      }
+      llvm::dbgs() << "\n    Users: ";
+      for (auto *user : node->users) {
+        llvm::dbgs() << "\n      " << *user->op;
+      }
+      llvm::dbgs() << "\n";
+    }
+    llvm::dbgs() << "===================\n";
+  }
+
+private:
+  /// Owning storage, in insertion order.
+  SmallVector<std::unique_ptr<SLPGraphNode>> nodes;
+};
+
+/// Build the SLP graph starting from memory operation roots.
+///
+/// Every operation in `rootGroups` becomes a root node. For each root, the
+/// direct users of its associated value are examined. A user that is an arith
+/// add/sub/mul (integer or float) is added as a non-root node if at least one
+/// of its other operands is defined by an operation already in the graph.
+/// Edges are then added from the root and from every in-graph operand
+/// definition to the new node. Users of non-root nodes are not explored.
+///
+/// NOTE(review): an operation whose only in-graph operand is the root value
+/// itself (e.g. `arith.addi %x, %x`) is never added, because operands equal
+/// to `rootValue` are skipped in the check below.
+SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
+  SLPGraph graph;
+  DenseMap<Operation *, SLPGraphNode *> opToNode;
+
+  // First, add all memory operations as roots
+  for (const auto &group : rootGroups) {
+    for (Operation *op : group.ops) {
+      opToNode[op] = graph.addRoot(op);
+    }
+  }
+
+  // Process each root group to build the graph
+  for (const auto &group : rootGroups) {
+    for (Operation *rootOp : group.ops) {
+      // Value associated with the root: the loaded result for a load, or the
+      // value being stored for a store (a store produces no result).
+      Value rootValue = group.isLoadGroup()
+                            ? cast<memref::LoadOp>(rootOp).getResult()
+                            : cast<memref::StoreOp>(rootOp).getValue();
+
+      // Find all users of this value
+      for (Operation *user : rootValue.getUsers()) {
+        // Skip users that already have a node (roots such as the store
+        // itself, or operations added while visiting an earlier root).
+        if (opToNode.contains(user))
+          continue;
+
+        // Check if this is a vectorizable operation
+        if (isa<arith::AddFOp, arith::AddIOp, arith::SubFOp, arith::SubIOp,
+                arith::MulFOp, arith::MulIOp>(user)) {
+          // Check if at least one other operand is already in the graph
+          bool hasGraphOperand = false;
+          for (Value operand : user->getOperands()) {
+            if (operand == rootValue)
+              continue;
+            if (auto *defOp = operand.getDefiningOp()) {
+              if (opToNode.contains(defOp)) {
+                hasGraphOperand = true;
+                break;
+              }
+            }
+          }
+
+          // Only add the operation if it has at least one other operand in the
+          // graph
+          if (hasGraphOperand) {
+            auto *node = graph.addNode(user);
+            opToNode[user] = node;
+            graph.addEdge(opToNode[rootOp], node);
+
+            // Add edges from other operands that are in the graph. The root's
+            // own edge may be re-added here; the node sets ignore duplicates.
+            for (Value operand : user->getOperands()) {
+              if (auto *defOp = operand.getDefiningOp()) {
+                if (opToNode.contains(defOp)) {
+                  graph.addEdge(opToNode[defOp], node);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
+  return graph;
+}
+
 /// This pass implements the SLP vectorizer. It detects consecutive operations
 /// that can be put together into vector operations. The pass works bottom-up,
 /// across basic blocks, in search of scalars to combine.
@@ -192,6 +346,7 @@ void SLPVectorizerPass::runOnOperation() {
     SmallVector<MemoryOpGroup> groups = collectMemoryOpGroups(*block);
 
     // Process each group to find contiguous sequences
+    SmallVector<MemoryOpGroup> rootGroups;
     for (const auto &group : groups) {
       SmallVector<MemoryOpGroup> contiguousGroups =
           extractContiguousGroups(group);
@@ -204,7 +359,14 @@ void SLPVectorizerPass::runOnOperation() {
                        << " operations\n";
         }
       });
+      rootGroups.append(contiguousGroups.begin(), contiguousGroups.end());
     }
+
+    // Build the SLP graph from root groups
+    SLPGraph graph = buildSLPGraph(rootGroups);
+
+    // Print the graph structure
+    LLVM_DEBUG(graph.print());
   });
 }
 

>From 4a5409137117ac56f1b4d6b8f8f28f7eb6291f22 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 00:53:58 +0200
Subject: [PATCH 07/28] SLPGraph

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 129 ++++++------------
 1 file changed, 38 insertions(+), 91 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 4355dc33648c0..3c4fc3a377244 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -130,29 +130,28 @@ SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
   return result;
 }
 
-/// A node in the SLP graph representing a vectorizable operation
+/// A node in the SLP graph representing a group of vectorizable operations
 struct SLPGraphNode {
-  Operation *op;
+  SmallVector<Operation *> ops;
   DenseSet<SLPGraphNode *> users;
   DenseSet<SLPGraphNode *> operands;
   bool isRoot = false;
 
-  SLPGraphNode(Operation *op) : op(op) {}
+  SLPGraphNode() = default;
+  SLPGraphNode(Operation *op) { ops.push_back(op); }
+  void addOp(Operation *op) { ops.push_back(op); }
 };
 
 /// A graph of vectorizable operations
 class SLPGraph {
 public:
   SLPGraph() = default;
-  ~SLPGraph() {
-    for (auto *node : nodes)
-      delete node;
-  }
+  ~SLPGraph() = default;
 
   /// Add a new node to the graph
   SLPGraphNode *addNode(Operation *op) {
-    nodes.push_back(new SLPGraphNode(op));
-    return nodes.back();
+    nodes.push_back(std::make_unique<SLPGraphNode>(op));
+    return nodes.back().get();
   }
 
   /// Add a root node (memory operation)
@@ -171,9 +170,9 @@ class SLPGraph {
   /// Get all root nodes
   SmallVector<SLPGraphNode *> getRoots() const {
     SmallVector<SLPGraphNode *> roots;
-    for (auto *node : nodes)
+    for (const auto &node : nodes)
       if (node->isRoot)
-        roots.push_back(node);
+        roots.push_back(node.get());
     return roots;
   }
 
@@ -184,30 +183,50 @@ class SLPGraph {
 
     // First print all roots
     llvm::dbgs() << "Roots:\n";
-    for (auto *node : nodes) {
+    for (const auto &node : nodes) {
       if (!node->isRoot)
         continue;
-      llvm::dbgs() << "  " << *node->op << "\n";
+      llvm::dbgs() << "  "
+                   << (isa<memref::LoadOp>(node->ops[0]) ? "LOAD" : "STORE")
+                   << " group with " << node->ops.size() << " operations:\n";
+      for (auto *op : node->ops) {
+        llvm::dbgs() << "    " << *op << "\n";
+      }
       llvm::dbgs() << "    Users: ";
       for (auto *user : node->users) {
-        llvm::dbgs() << "\n      " << *user->op;
+        llvm::dbgs() << "\n      Group with " << user->ops.size()
+                     << " operations:";
+        for (auto *op : user->ops) {
+          llvm::dbgs() << "\n        " << *op;
+        }
       }
       llvm::dbgs() << "\n";
     }
 
     // Then print all non-root nodes
     llvm::dbgs() << "\nNon-root nodes:\n";
-    for (auto *node : nodes) {
+    for (const auto &node : nodes) {
       if (node->isRoot)
         continue;
-      llvm::dbgs() << "  " << *node->op << "\n";
+      llvm::dbgs() << "  Group with " << node->ops.size() << " operations:\n";
+      for (auto *op : node->ops) {
+        llvm::dbgs() << "    " << *op << "\n";
+      }
       llvm::dbgs() << "    Operands: ";
       for (auto *operand : node->operands) {
-        llvm::dbgs() << "\n      " << *operand->op;
+        llvm::dbgs() << "\n      Group with " << operand->ops.size()
+                     << " operations:";
+        for (auto *op : operand->ops) {
+          llvm::dbgs() << "\n        " << *op;
+        }
       }
       llvm::dbgs() << "\n    Users: ";
       for (auto *user : node->users) {
-        llvm::dbgs() << "\n      " << *user->op;
+        llvm::dbgs() << "\n      Group with " << user->ops.size()
+                     << " operations:";
+        for (auto *op : user->ops) {
+          llvm::dbgs() << "\n        " << *op;
+        }
       }
       llvm::dbgs() << "\n";
     }
@@ -215,75 +234,9 @@ class SLPGraph {
   }
 
 private:
-  SmallVector<SLPGraphNode *> nodes;
+  SmallVector<std::unique_ptr<SLPGraphNode>> nodes;
 };
 
-/// Build the SLP graph starting from memory operation roots
-SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
-  SLPGraph graph;
-  DenseMap<Operation *, SLPGraphNode *> opToNode;
-
-  // First, add all memory operations as roots
-  for (const auto &group : rootGroups) {
-    for (Operation *op : group.ops) {
-      opToNode[op] = graph.addRoot(op);
-    }
-  }
-
-  // Process each root group to build the graph
-  for (const auto &group : rootGroups) {
-    for (Operation *rootOp : group.ops) {
-      // Get the value produced by this memory operation
-      Value rootValue = group.isLoadGroup()
-                            ? cast<memref::LoadOp>(rootOp).getResult()
-                            : cast<memref::StoreOp>(rootOp).getValue();
-
-      // Find all users of this value
-      for (Operation *user : rootValue.getUsers()) {
-        // Skip if we've already processed this operation
-        if (opToNode.contains(user))
-          continue;
-
-        // Check if this is a vectorizable operation
-        if (isa<arith::AddFOp, arith::AddIOp, arith::SubFOp, arith::SubIOp,
-                arith::MulFOp, arith::MulIOp>(user)) {
-          // Check if at least one other operand is already in the graph
-          bool hasGraphOperand = false;
-          for (Value operand : user->getOperands()) {
-            if (operand == rootValue)
-              continue;
-            if (auto *defOp = operand.getDefiningOp()) {
-              if (opToNode.contains(defOp)) {
-                hasGraphOperand = true;
-                break;
-              }
-            }
-          }
-
-          // Only add the operation if it has at least one other operand in the
-          // graph
-          if (hasGraphOperand) {
-            auto *node = graph.addNode(user);
-            opToNode[user] = node;
-            graph.addEdge(opToNode[rootOp], node);
-
-            // Add edges from other operands that are in the graph
-            for (Value operand : user->getOperands()) {
-              if (auto *defOp = operand.getDefiningOp()) {
-                if (opToNode.contains(defOp)) {
-                  graph.addEdge(opToNode[defOp], node);
-                }
-              }
-            }
-          }
-        }
-      }
-    }
-  }
-
-  return graph;
-}
-
 /// This pass implements the SLP vectorizer. It detects consecutive operations
 /// that can be put together into vector operations. The pass works bottom-up,
 /// across basic blocks, in search of scalars to combine.
@@ -361,12 +314,6 @@ void SLPVectorizerPass::runOnOperation() {
       });
       rootGroups.append(contiguousGroups.begin(), contiguousGroups.end());
     }
-
-    // Build the SLP graph from root groups
-    SLPGraph graph = buildSLPGraph(rootGroups);
-
-    // Print the graph structure
-    LLVM_DEBUG(graph.print());
   });
 }
 

>From 96e0fe8232a27d47017b4e7366913b4e5966ef78 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 01:04:27 +0200
Subject: [PATCH 08/28] work

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 48 ++++++++++++++++---
 1 file changed, 41 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 3c4fc3a377244..8e49b622ac39b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -138,8 +138,8 @@ struct SLPGraphNode {
   bool isRoot = false;
 
   SLPGraphNode() = default;
-  SLPGraphNode(Operation *op) { ops.push_back(op); }
-  void addOp(Operation *op) { ops.push_back(op); }
+  SLPGraphNode(ArrayRef<Operation *> operations)
+      : ops(operations.begin(), operations.end()) {}
 };
 
 /// A graph of vectorizable operations
@@ -148,15 +148,23 @@ class SLPGraph {
   SLPGraph() = default;
   ~SLPGraph() = default;
 
+  // Delete copy constructor and assignment operator
+  SLPGraph(const SLPGraph &) = delete;
+  SLPGraph &operator=(const SLPGraph &) = delete;
+
+  // Allow move operations
+  SLPGraph(SLPGraph &&) = default;
+  SLPGraph &operator=(SLPGraph &&) = default;
+
   /// Add a new node to the graph
-  SLPGraphNode *addNode(Operation *op) {
-    nodes.push_back(std::make_unique<SLPGraphNode>(op));
+  SLPGraphNode *addNode(ArrayRef<Operation *> operations) {
+    nodes.push_back(std::make_unique<SLPGraphNode>(operations));
     return nodes.back().get();
   }
 
   /// Add a root node (memory operation)
-  SLPGraphNode *addRoot(Operation *op) {
-    auto *node = addNode(op);
+  SLPGraphNode *addRoot(ArrayRef<Operation *> operations) {
+    auto *node = addNode(operations);
     node->isRoot = true;
     return node;
   }
@@ -251,7 +259,25 @@ struct SLPVectorizerPass
   SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block);
 };
 
-} // namespace
+/// Build the SLP graph starting from memory operation groups
+SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
+  SLPGraph graph;
+
+  // First, create nodes for each contiguous memory operation group
+  for (const auto &group : rootGroups) {
+    // Create a new node for this group
+    auto *node = graph.addRoot(group.ops);
+    node->isRoot = true;
+
+    LLVM_DEBUG({
+      llvm::dbgs() << "Created " << (group.isLoadGroup() ? "LOAD" : "STORE")
+                   << " group node with " << node->ops.size()
+                   << " operations\n";
+    });
+  }
+
+  return graph;
+}
 
 SmallVector<MemoryOpGroup>
 SLPVectorizerPass::collectMemoryOpGroups(Block &block) {
@@ -314,9 +340,17 @@ void SLPVectorizerPass::runOnOperation() {
       });
       rootGroups.append(contiguousGroups.begin(), contiguousGroups.end());
     }
+
+    // Build the SLP graph from root groups
+    SLPGraph graph = buildSLPGraph(rootGroups);
+
+    // Print the graph structure
+    LLVM_DEBUG(graph.print());
   });
 }
 
+} // namespace
+
 std::unique_ptr<Pass> mlir::vector::createSLPVectorizerPass() {
   return std::make_unique<SLPVectorizerPass>();
 }

>From c19e55782daed0c57bb4604192dc73fa50ff4709 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 12:55:28 +0200
Subject: [PATCH 09/28] fingerprinting

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 170 ++++++++++++++++--
 1 file changed, 158 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 8e49b622ac39b..3e6a4ca05f87d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/SHA1.h"
 
 #define DEBUG_TYPE "slp-vectorizer"
 
@@ -64,7 +65,8 @@ std::optional<std::pair<Value, int64_t>> getBaseAndIndex(Operation *op) {
 }
 
 // Extract contiguous groups from a MemoryOpGroup
-SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
+static SmallVector<MemoryOpGroup>
+extractContiguousGroups(const MemoryOpGroup &group) {
   SmallVector<MemoryOpGroup> result;
   if (group.ops.empty())
     return result;
@@ -133,8 +135,8 @@ SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
 /// A node in the SLP graph representing a group of vectorizable operations
 struct SLPGraphNode {
   SmallVector<Operation *> ops;
-  DenseSet<SLPGraphNode *> users;
-  DenseSet<SLPGraphNode *> operands;
+  llvm::SmallDenseSet<SLPGraphNode *> users;
+  llvm::SmallDenseSet<SLPGraphNode *> operands;
   bool isRoot = false;
 
   SLPGraphNode() = default;
@@ -148,11 +150,9 @@ class SLPGraph {
   SLPGraph() = default;
   ~SLPGraph() = default;
 
-  // Delete copy constructor and assignment operator
   SLPGraph(const SLPGraph &) = delete;
   SLPGraph &operator=(const SLPGraph &) = delete;
 
-  // Allow move operations
   SLPGraph(SLPGraph &&) = default;
   SLPGraph &operator=(SLPGraph &&) = default;
 
@@ -259,21 +259,168 @@ struct SLPVectorizerPass
   SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block);
 };
 
+static bool isVectorizable(Operation *op) {
+  return OpTrait::hasElementwiseMappableTraits(op);
+}
+
+using Fingerprint = std::array<uint8_t, 20>;
+
+template <typename T>
+static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
+  hasher.update(
+      ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
+}
+
+struct OperationsFingerprint {
+  OperationsFingerprint(
+      const llvm::SmallDenseMap<Operation *, SLPGraphNode *> &opToNode)
+      : opToNode(opToNode) {}
+
+  Fingerprint getFingerprint(Operation *op) {
+    auto it = fingerprints.find(op);
+    if (it != fingerprints.end())
+      return it->second;
+
+    SmallVector<Operation *> worklist;
+    SmallVector<Operation *> toposortedOps;
+    worklist.emplace_back(op);
+    while (!worklist.empty()) {
+      Operation *op = worklist.pop_back_val();
+      toposortedOps.emplace_back(op);
+      if (opToNode.contains(op))
+        continue;
+
+      for (Value operand : op->getOperands()) {
+        auto *defOp = operand.getDefiningOp();
+        if (!defOp || !isVectorizable(defOp))
+          continue;
+
+        toposortedOps.emplace_back(defOp);
+        worklist.emplace_back(defOp);
+      }
+    }
+
+    for (Operation *op : llvm::reverse(toposortedOps)) {
+      llvm::SHA1 hasher;
+      addDataToHash(hasher, op->getName().getTypeID());
+      addDataToHash(hasher, op->getRawDictionaryAttrs());
+      addDataToHash(hasher, op->hashProperties());
+      for (Value operand : op->getOperands()) {
+        auto *defOp = operand.getDefiningOp();
+        if (!defOp)
+          continue;
+
+        auto it1 = opToNode.find(defOp);
+        if (it1 != opToNode.end()) {
+          addDataToHash(hasher, it1->second);
+          continue;
+        }
+
+        auto it2 = fingerprints.find(defOp);
+        if (it2 != fingerprints.end()) {
+          addDataToHash(hasher, it2->second);
+          continue;
+        }
+      }
+      fingerprints[op] = hasher.result();
+    }
+
+    return fingerprints[op];
+  }
+
+  void invalidate(Operation *op) {
+    if (fingerprints.contains(op))
+      fingerprints.clear();
+  }
+
+  const llvm::SmallDenseMap<Operation *, SLPGraphNode *> &opToNode;
+  DenseMap<Operation *, Fingerprint> fingerprints;
+};
+
+static bool isEquivalent(Operation *op1, Operation *op2) {
+  if (op1->getName() != op2->getName())
+    return false;
+
+  if (op1->getRawDictionaryAttrs() != op2->getRawDictionaryAttrs())
+    return false;
+
+  return true;
+}
+
 /// Build the SLP graph starting from memory operation groups
-SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
+static SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
   SLPGraph graph;
+  llvm::SmallDenseMap<Operation *, SLPGraphNode *> opToNode;
+
+  SmallVector<SLPGraphNode *> worklist;
 
   // First, create nodes for each contiguous memory operation group
   for (const auto &group : rootGroups) {
-    // Create a new node for this group
     auto *node = graph.addRoot(group.ops);
-    node->isRoot = true;
+    for (Operation *op : group.ops)
+      opToNode[op] = node;
+
+    worklist.push_back(node);
 
     LLVM_DEBUG({
-      llvm::dbgs() << "Created " << (group.isLoadGroup() ? "LOAD" : "STORE")
-                   << " group node with " << node->ops.size()
-                   << " operations\n";
+      llvm::dbgs() << "Created root group node with " << node->ops.size()
+                   << " operations of type "
+                   << (group.type == MemoryOpGroup::Type::Load ? "Load"
+                                                               : "Store")
+                   << "\n";
     });
+
+    OperationsFingerprint fingerprints(opToNode);
+
+    auto processUse = [&](SLPGraphNode *node, OpOperand &use) {
+      Operation *user = use.getOwner();
+      if (opToNode.contains(user) || !isVectorizable(user))
+        return;
+
+      Fingerprint expectedFingerprint = fingerprints.getFingerprint(user);
+
+      SmallVector<Operation *> currentOps;
+      currentOps.emplace_back(user);
+      for (Operation *op : ArrayRef(node->ops).drop_front()) {
+        Operation *found = nullptr;
+        for (OpOperand &opUse : op->getUses()) {
+          if (opUse.getOperandNumber() != use.getOperandNumber())
+            continue;
+
+          Operation *useOwner = opUse.getOwner();
+          if (!isEquivalent(useOwner, user) ||
+              fingerprints.getFingerprint(useOwner) != expectedFingerprint)
+            continue;
+
+          found = useOwner;
+          break;
+        }
+        if (!found)
+          break;
+
+        currentOps.push_back(found);
+      }
+
+      if (currentOps.size() == 1)
+        return;
+
+      auto *newNode = graph.addNode(currentOps);
+      graph.addEdge(node, newNode);
+      for (Operation *op : currentOps) {
+        opToNode[op] = newNode;
+        fingerprints.invalidate(op);
+      }
+
+      worklist.push_back(newNode);
+    };
+
+    while (!worklist.empty()) {
+      SLPGraphNode *node = worklist.pop_back_val();
+
+      Operation *op = node->ops.front();
+      for (OpOperand &use : op->getUses())
+        processUse(node, use);
+    }
   }
 
   return graph;
@@ -314,7 +461,6 @@ SLPVectorizerPass::collectMemoryOpGroups(Block &block) {
 
 void SLPVectorizerPass::runOnOperation() {
   Operation *op = getOperation();
-  MLIRContext *context = &getContext();
 
   // Walk all blocks recursively
   op->walk([&](Block *block) {

>From dfdfc8948cdfd70d068234d310a422eeaef45e3a Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 13:58:50 +0200
Subject: [PATCH 10/28] graph

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 106 ++++++++++--------
 1 file changed, 62 insertions(+), 44 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 3e6a4ca05f87d..28c53efea7512 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -185,7 +185,7 @@ class SLPGraph {
   }
 
   /// Print the graph structure
-  void print() const {
+  [[maybe_unused]] void print() const {
     llvm::dbgs() << "SLP Graph Structure:\n";
     llvm::dbgs() << "===================\n";
 
@@ -348,7 +348,12 @@ static bool isEquivalent(Operation *op1, Operation *op2) {
 }
 
 /// Build the SLP graph starting from memory operation groups
-static SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
+static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
+  if (rootGroups.empty())
+    return SLPGraph();
+
+  LLVM_DEBUG(llvm::dbgs() << "=== Building SLP graph from " << rootGroups.size()
+                          << " root groups ===\n");
   SLPGraph graph;
   llvm::SmallDenseMap<Operation *, SLPGraphNode *> opToNode;
 
@@ -365,61 +370,74 @@ static SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
     LLVM_DEBUG({
       llvm::dbgs() << "Created root group node with " << node->ops.size()
                    << " operations of type "
-                   << (group.type == MemoryOpGroup::Type::Load ? "Load"
-                                                               : "Store")
-                   << "\n";
+                   << (group.isLoadGroup() ? "Load" : "Store") << "\n";
     });
+  }
 
-    OperationsFingerprint fingerprints(opToNode);
-
-    auto processUse = [&](SLPGraphNode *node, OpOperand &use) {
-      Operation *user = use.getOwner();
-      if (opToNode.contains(user) || !isVectorizable(user))
-        return;
+  OperationsFingerprint fingerprints(opToNode);
+
+  auto processUse = [&](SLPGraphNode *node, OpOperand &use) {
+    Operation *user = use.getOwner();
+    auto it = opToNode.find(user);
+    if (it != opToNode.end()) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "  Adding edge from " << node->ops.front()->getName()
+                 << " to " << it->first->getName() << "\n");
+      graph.addEdge(node, it->second);
+      return;
+    }
 
-      Fingerprint expectedFingerprint = fingerprints.getFingerprint(user);
+    if (!isVectorizable(user))
+      return;
 
-      SmallVector<Operation *> currentOps;
-      currentOps.emplace_back(user);
-      for (Operation *op : ArrayRef(node->ops).drop_front()) {
-        Operation *found = nullptr;
-        for (OpOperand &opUse : op->getUses()) {
-          if (opUse.getOperandNumber() != use.getOperandNumber())
-            continue;
+    Fingerprint expectedFingerprint = fingerprints.getFingerprint(user);
 
-          Operation *useOwner = opUse.getOwner();
-          if (!isEquivalent(useOwner, user) ||
-              fingerprints.getFingerprint(useOwner) != expectedFingerprint)
-            continue;
+    SmallVector<Operation *> currentOps;
+    currentOps.emplace_back(user);
+    for (Operation *op : ArrayRef(node->ops).drop_front()) {
+      Operation *found = nullptr;
+      for (OpOperand &opUse : op->getUses()) {
+        if (opUse.getOperandNumber() != use.getOperandNumber())
+          continue;
 
-          found = useOwner;
-          break;
-        }
-        if (!found)
-          break;
+        Operation *useOwner = opUse.getOwner();
+        if (!isEquivalent(useOwner, user) ||
+            fingerprints.getFingerprint(useOwner) != expectedFingerprint)
+          continue;
 
-        currentOps.push_back(found);
+        found = useOwner;
+        break;
       }
+      if (!found)
+        break;
 
-      if (currentOps.size() == 1)
-        return;
+      currentOps.push_back(found);
+    }
 
-      auto *newNode = graph.addNode(currentOps);
-      graph.addEdge(node, newNode);
-      for (Operation *op : currentOps) {
-        opToNode[op] = newNode;
-        fingerprints.invalidate(op);
-      }
+    if (currentOps.size() == 1)
+      return;
 
-      worklist.push_back(newNode);
-    };
+    auto *newNode = graph.addNode(currentOps);
+    graph.addEdge(node, newNode);
+    for (Operation *op : currentOps) {
+      opToNode[op] = newNode;
+      fingerprints.invalidate(op);
+    }
 
-    while (!worklist.empty()) {
-      SLPGraphNode *node = worklist.pop_back_val();
+    worklist.push_back(newNode);
+  };
+
+  while (!worklist.empty()) {
+    SLPGraphNode *node = worklist.pop_back_val();
+    LLVM_DEBUG(llvm::dbgs() << "Processing node with " << node->ops.size()
+                            << " operations, first op: "
+                            << node->ops.front()->getName() << "\n");
 
-      Operation *op = node->ops.front();
-      for (OpOperand &use : op->getUses())
-        processUse(node, use);
+    Operation *op = node->ops.front();
+    for (OpOperand &use : op->getUses()) {
+      processUse(node, use);
+      LLVM_DEBUG(llvm::dbgs() << "  Processing use in operation: "
+                              << use.getOwner()->getName() << "\n");
     }
   }
 

>From 3272e0192e6c6a4c2abd6e170ffa79ede3aa07da Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 14:05:07 +0200
Subject: [PATCH 11/28] refac

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 41 ++++++++++---------
 1 file changed, 22 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 28c53efea7512..e3b39ba10373c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -159,7 +159,10 @@ class SLPGraph {
   /// Add a new node to the graph
   SLPGraphNode *addNode(ArrayRef<Operation *> operations) {
     nodes.push_back(std::make_unique<SLPGraphNode>(operations));
-    return nodes.back().get();
+    auto *node = nodes.back().get();
+    for (Operation *op : operations)
+      opToNode[op] = node;
+    return node;
   }
 
   /// Add a root node (memory operation)
@@ -184,6 +187,12 @@ class SLPGraph {
     return roots;
   }
 
+  /// Get the node associated with an operation
+  SLPGraphNode *getNodeForOp(Operation *op) const {
+    auto it = opToNode.find(op);
+    return it != opToNode.end() ? it->second : nullptr;
+  }
+
   /// Print the graph structure
   [[maybe_unused]] void print() const {
     llvm::dbgs() << "SLP Graph Structure:\n";
@@ -243,6 +252,7 @@ class SLPGraph {
 
 private:
   SmallVector<std::unique_ptr<SLPGraphNode>> nodes;
+  llvm::SmallDenseMap<Operation *, SLPGraphNode *> opToNode;
 };
 
 /// This pass implements the SLP vectorizer. It detects consecutive operations
@@ -272,9 +282,7 @@ static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
 }
 
 struct OperationsFingerprint {
-  OperationsFingerprint(
-      const llvm::SmallDenseMap<Operation *, SLPGraphNode *> &opToNode)
-      : opToNode(opToNode) {}
+  OperationsFingerprint(const SLPGraph &graph) : graph(graph) {}
 
   Fingerprint getFingerprint(Operation *op) {
     auto it = fingerprints.find(op);
@@ -287,7 +295,7 @@ struct OperationsFingerprint {
     while (!worklist.empty()) {
       Operation *op = worklist.pop_back_val();
       toposortedOps.emplace_back(op);
-      if (opToNode.contains(op))
+      if (graph.getNodeForOp(op))
         continue;
 
       for (Value operand : op->getOperands()) {
@@ -310,9 +318,9 @@ struct OperationsFingerprint {
         if (!defOp)
           continue;
 
-        auto it1 = opToNode.find(defOp);
-        if (it1 != opToNode.end()) {
-          addDataToHash(hasher, it1->second);
+        auto *node = graph.getNodeForOp(defOp);
+        if (node) {
+          addDataToHash(hasher, node);
           continue;
         }
 
@@ -333,7 +341,7 @@ struct OperationsFingerprint {
       fingerprints.clear();
   }
 
-  const llvm::SmallDenseMap<Operation *, SLPGraphNode *> &opToNode;
+  const SLPGraph &graph;
   DenseMap<Operation *, Fingerprint> fingerprints;
 };
 
@@ -355,16 +363,12 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
   LLVM_DEBUG(llvm::dbgs() << "=== Building SLP graph from " << rootGroups.size()
                           << " root groups ===\n");
   SLPGraph graph;
-  llvm::SmallDenseMap<Operation *, SLPGraphNode *> opToNode;
 
   SmallVector<SLPGraphNode *> worklist;
 
   // First, create nodes for each contiguous memory operation group
   for (const auto &group : rootGroups) {
     auto *node = graph.addRoot(group.ops);
-    for (Operation *op : group.ops)
-      opToNode[op] = node;
-
     worklist.push_back(node);
 
     LLVM_DEBUG({
@@ -374,16 +378,16 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
     });
   }
 
-  OperationsFingerprint fingerprints(opToNode);
+  OperationsFingerprint fingerprints(graph);
 
   auto processUse = [&](SLPGraphNode *node, OpOperand &use) {
     Operation *user = use.getOwner();
-    auto it = opToNode.find(user);
-    if (it != opToNode.end()) {
+    auto *existingNode = graph.getNodeForOp(user);
+    if (existingNode) {
       LLVM_DEBUG(llvm::dbgs()
                  << "  Adding edge from " << node->ops.front()->getName()
-                 << " to " << it->first->getName() << "\n");
-      graph.addEdge(node, it->second);
+                 << " to " << user->getName() << "\n");
+      graph.addEdge(node, existingNode);
       return;
     }
 
@@ -420,7 +424,6 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
     auto *newNode = graph.addNode(currentOps);
     graph.addEdge(node, newNode);
     for (Operation *op : currentOps) {
-      opToNode[op] = newNode;
       fingerprints.invalidate(op);
     }
 

>From 04b2316c9859d734af8f968072eb298bbf96cfd9 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 14:59:20 +0200
Subject: [PATCH 12/28] toposort

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 89 ++++++++++++++++++-
 1 file changed, 85 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index e3b39ba10373c..8f0137a12d07b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -135,8 +135,8 @@ extractContiguousGroups(const MemoryOpGroup &group) {
 /// A node in the SLP graph representing a group of vectorizable operations
 struct SLPGraphNode {
   SmallVector<Operation *> ops;
-  llvm::SmallDenseSet<SLPGraphNode *> users;
-  llvm::SmallDenseSet<SLPGraphNode *> operands;
+  SmallVector<SLPGraphNode *> users;
+  SmallVector<SLPGraphNode *> operands;
   bool isRoot = false;
 
   SLPGraphNode() = default;
@@ -174,8 +174,8 @@ class SLPGraph {
 
   /// Add a dependency edge between nodes
   void addEdge(SLPGraphNode *from, SLPGraphNode *to) {
-    from->users.insert(to);
-    to->operands.insert(from);
+    from->users.push_back(to);
+    to->operands.push_back(from);
   }
 
   /// Get all root nodes
@@ -193,6 +193,80 @@ class SLPGraph {
     return it != opToNode.end() ? it->second : nullptr;
   }
 
+  /// Topologically sort the nodes in the graph
+  SmallVector<SLPGraphNode *> topologicalSort() const {
+    SmallVector<SLPGraphNode *> result;
+    llvm::SmallDenseSet<SLPGraphNode *> visited;
+
+    SmallVector<SLPGraphNode *> stack;
+
+    // Process each node
+    for (const auto &node : nodes) {
+      if (visited.contains(node.get()))
+        continue;
+
+      stack.emplace_back(node.get());
+      while (!stack.empty()) {
+        SLPGraphNode *node = stack.pop_back_val();
+        if (visited.contains(node))
+          continue;
+
+        stack.push_back(node);
+
+        bool pushed = false;
+        for (SLPGraphNode *operand : node->operands) {
+          if (visited.contains(operand))
+            continue;
+
+          stack.push_back(operand);
+          pushed = true;
+        }
+
+        if (!pushed) {
+          visited.insert(node);
+          result.push_back(node);
+        }
+      }
+    }
+
+    return result;
+  }
+
+  /// Vectorize the operations in the graph
+  LogicalResult vectorize(IRRewriter &rewriter) {
+    if (nodes.empty())
+      return success();
+
+    LLVM_DEBUG(llvm::dbgs()
+               << "Vectorizing SLP graph with " << nodes.size() << " nodes\n");
+
+    // Get topologically sorted nodes
+    SmallVector<SLPGraphNode *> sortedNodes = topologicalSort();
+    if (sortedNodes.empty()) {
+      LLVM_DEBUG(llvm::dbgs() << "Failed to topologically sort nodes\n");
+      return failure();
+    }
+
+    LLVM_DEBUG({
+      llvm::dbgs() << "Topologically sorted nodes:\n";
+      for (auto *node : sortedNodes) {
+        llvm::dbgs() << "  Node with " << node->ops.size()
+                     << " operations: " << node->ops.front()->getName() << "\n";
+      }
+    });
+
+    // TODO: Implement vectorization logic:
+    // 1. Process nodes in topological order
+    // 2. For each node:
+    //    a. Check if all operands are vectorized
+    //    b. Create vector operation
+    //    c. Replace scalar operations with vector operation
+    // 3. Handle memory operations (loads/stores) specially
+    // 4. Update use-def chains
+
+    return success();
+  }
+
   /// Print the graph structure
   [[maybe_unused]] void print() const {
     llvm::dbgs() << "SLP Graph Structure:\n";
@@ -513,6 +587,13 @@ void SLPVectorizerPass::runOnOperation() {
 
     // Print the graph structure
     LLVM_DEBUG(graph.print());
+
+    // Vectorize the graph
+    IRRewriter rewriter(&getContext());
+    if (failed(graph.vectorize(rewriter))) {
+      LLVM_DEBUG(llvm::dbgs() << "Failed to vectorize graph\n");
+      return signalPassFailure();
+    }
   });
 }
 

>From 740011b4c25913e89015076e84b583221c6c047d Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 15:30:01 +0200
Subject: [PATCH 13/28] codegen

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 68 +++++++++++++++----
 1 file changed, 56 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 8f0137a12d07b..095ad4f11a91a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -132,6 +132,10 @@ extractContiguousGroups(const MemoryOpGroup &group) {
   return result;
 }
 
+static bool isVectorizable(Operation *op) {
+  return OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1;
+}
+
 /// A node in the SLP graph representing a group of vectorizable operations
 struct SLPGraphNode {
   SmallVector<Operation *> ops;
@@ -255,14 +259,58 @@ class SLPGraph {
       }
     });
 
-    // TODO: Implement vectorization logic:
-    // 1. Process nodes in topological order
-    // 2. For each node:
-    //    a. Check if all operands are vectorized
-    //    b. Create vector operation
-    //    c. Replace scalar operations with vector operation
-    // 3. Handle memory operations (loads/stores) specially
-    // 4. Update use-def chains
+    IRMapping mapping;
+    for (auto *node : sortedNodes) {
+      if (node->users.empty() && node->operands.empty())
+        continue;
+
+      Operation *op = node->ops.front();
+      rewriter.setInsertionPoint(op);
+      Location loc = op->getLoc();
+      int64_t numElements = node->ops.size();
+      if (auto load = dyn_cast<memref::LoadOp>(op)) {
+        auto vecType =
+            VectorType::get(numElements, load.getMemRefType().getElementType());
+        Value result = rewriter.create<vector::LoadOp>(
+            loc, vecType, load.getMemRef(), load.getIndices());
+        mapping.map(load.getResult(), result);
+      } else if (auto store = dyn_cast<memref::StoreOp>(op)) {
+        Value val = mapping.lookupOrDefault(store.getValueToStore());
+        rewriter.create<vector::StoreOp>(loc, val, store.getMemRef(),
+                                         store.getIndices());
+      } else if (isVectorizable(op)) {
+        auto vecType =
+            VectorType::get(numElements, op->getResultTypes().front());
+        for (auto [i, operand] : llvm::enumerate(op->getOperands())) {
+          if (getNodeForOp(operand.getDefiningOp()))
+            continue;
+
+          SmallVector<Value> args;
+          for (Operation *defOp : node->ops)
+            args.push_back(defOp->getOperand(i));
+
+          Value vector =
+              rewriter.create<vector::FromElementsOp>(loc, vecType, args);
+          mapping.map(operand, vector);
+        }
+
+        Operation *newOp = rewriter.clone(*op, mapping);
+        auto resVectorType =
+            VectorType::get(numElements, op->getResultTypes().front());
+        newOp->getResult(0).setType(resVectorType);
+
+        mapping.map(op->getResults(), newOp->getResults());
+      } else {
+        op->emitError("unsupported operation");
+        return failure();
+      }
+    }
+
+    for (auto *node : llvm::reverse(sortedNodes)) {
+      for (Operation *op : node->ops) {
+        rewriter.eraseOp(op);
+      }
+    }
 
     return success();
   }
@@ -343,10 +391,6 @@ struct SLPVectorizerPass
   SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block);
 };
 
-static bool isVectorizable(Operation *op) {
-  return OpTrait::hasElementwiseMappableTraits(op);
-}
-
 using Fingerprint = std::array<uint8_t, 20>;
 
 template <typename T>

>From bbd1122dd97f7719b3a75329fae58d8c72916f90 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 15:53:41 +0200
Subject: [PATCH 14/28] test

---
 mlir/test/Dialect/Vector/slp-vectorize.mlir | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index a07dd05dd16aa..266008e53ea43 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -1,10 +1,13 @@
 // RUN: mlir-opt %s --slp-vectorizer | FileCheck %s
 
-// CHECK-LABEL: func @basic_slp
-func.func @basic_slp(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
-  // CHECK: vector.transfer_read
-  // CHECK: arith.addi
-  // CHECK: vector.transfer_write
+// CHECK-LABEL: func @read_read_add_write
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+  // CHECK:     vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index

>From bcbb729dea49391c7f0b0ead2e23198e6fa0b816 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 15:57:06 +0200
Subject: [PATCH 15/28] test

---
 mlir/test/Dialect/Vector/slp-vectorize.mlir | 25 +++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 266008e53ea43..28a255f90a869 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -1,5 +1,30 @@
 // RUN: mlir-opt %s --slp-vectorizer | FileCheck %s
 
+// CHECK-LABEL: func @read_write
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[RES:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+  memref.store %0, %arg0[%c0] : memref<8xi32>
+  memref.store %1, %arg0[%c1] : memref<8xi32>
+  memref.store %2, %arg0[%c2] : memref<8xi32>
+  memref.store %3, %arg0[%c3] : memref<8xi32>
+
+  return
+}
+
+
 // CHECK-LABEL: func @read_read_add_write
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
 func.func @read_read_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {

>From 1f68586ad924bf8553438b40edfa6c44e2e2b017 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 16:25:45 +0200
Subject: [PATCH 16/28] fixes

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 49 +++++++++++++------
 mlir/test/Dialect/Vector/slp-vectorize.mlir   | 36 ++++++++++++++
 2 files changed, 70 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 095ad4f11a91a..a40131a1b10ff 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -264,24 +264,13 @@ class SLPGraph {
       if (node->users.empty() && node->operands.empty())
         continue;
 
+      int64_t numElements = node->ops.size();
       Operation *op = node->ops.front();
       rewriter.setInsertionPoint(op);
       Location loc = op->getLoc();
-      int64_t numElements = node->ops.size();
-      if (auto load = dyn_cast<memref::LoadOp>(op)) {
-        auto vecType =
-            VectorType::get(numElements, load.getMemRefType().getElementType());
-        Value result = rewriter.create<vector::LoadOp>(
-            loc, vecType, load.getMemRef(), load.getIndices());
-        mapping.map(load.getResult(), result);
-      } else if (auto store = dyn_cast<memref::StoreOp>(op)) {
-        Value val = mapping.lookupOrDefault(store.getValueToStore());
-        rewriter.create<vector::StoreOp>(loc, val, store.getMemRef(),
-                                         store.getIndices());
-      } else if (isVectorizable(op)) {
-        auto vecType =
-            VectorType::get(numElements, op->getResultTypes().front());
-        for (auto [i, operand] : llvm::enumerate(op->getOperands())) {
+
+      auto handleNonVectorInputs = [&](ValueRange operands) {
+        for (auto [i, operand] : llvm::enumerate(operands)) {
           if (getNodeForOp(operand.getDefiningOp()))
             continue;
 
@@ -289,17 +278,47 @@ class SLPGraph {
           for (Operation *defOp : node->ops)
             args.push_back(defOp->getOperand(i));
 
+          auto vecType = VectorType::get(numElements, operand.getType());
           Value vector =
               rewriter.create<vector::FromElementsOp>(loc, vecType, args);
           mapping.map(operand, vector);
         }
+      };
+
+      auto handleNonVectorOutputs = [&](Value newResult) {
+        for (auto [i, result] : llvm::enumerate(node->ops)) {
+          // use.set() unlinks `use`, so advance the iterator before it runs.
+          for (OpOperand &use : llvm::make_early_inc_range(result->getUses())) {
+            if (getNodeForOp(use.getOwner()))
+              continue;
+            // Users outside the graph still expect a scalar: give them lane i.
+            Value elem = rewriter.create<vector::ExtractOp>(loc, newResult, i);
+            use.set(elem);
+          }
+        }
+      };
 
+      if (auto load = dyn_cast<memref::LoadOp>(op)) {
+        auto vecType =
+            VectorType::get(numElements, load.getMemRefType().getElementType());
+        Value result = rewriter.create<vector::LoadOp>(
+            loc, vecType, load.getMemRef(), load.getIndices());
+        mapping.map(load.getResult(), result);
+        handleNonVectorOutputs(result);
+      } else if (auto store = dyn_cast<memref::StoreOp>(op)) {
+        handleNonVectorInputs(store.getValueToStore());
+        Value val = mapping.lookupOrDefault(store.getValueToStore());
+        rewriter.create<vector::StoreOp>(loc, val, store.getMemRef(),
+                                         store.getIndices());
+      } else if (isVectorizable(op)) {
+        handleNonVectorInputs(op->getOperands());
         Operation *newOp = rewriter.clone(*op, mapping);
         auto resVectorType =
             VectorType::get(numElements, op->getResultTypes().front());
         newOp->getResult(0).setType(resVectorType);
 
         mapping.map(op->getResults(), newOp->getResults());
+        handleNonVectorOutputs(newOp->getResult(0));
       } else {
         op->emitError("unsupported operation");
         return failure();
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 28a255f90a869..2b2b91d667e00 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s --slp-vectorizer | FileCheck %s
 
+
 // CHECK-LABEL: func @read_write
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
 func.func @read_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
@@ -24,6 +25,41 @@ func.func @read_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
   return
 }
 
+// CHECK-LABEL: func @read_read_add
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i32, i32, i32){
+  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+  // CHECK:     %[[R0:.*]] = vector.extract %[[RES]][0] : i32 from vector<4xi32>
+  // CHECK:     %[[R1:.*]] = vector.extract %[[RES]][1] : i32 from vector<4xi32>
+  // CHECK:     %[[R2:.*]] = vector.extract %[[RES]][2] : i32 from vector<4xi32>
+  // CHECK:     %[[R3:.*]] = vector.extract %[[RES]][3] : i32 from vector<4xi32>
+  // CHECK:     return %[[R0]], %[[R1]], %[[R2]], %[[R3]] : i32, i32, i32, i32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+  %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %5 = memref.load %arg1[%c1] : memref<8xi32>
+  %6 = memref.load %arg1[%c2] : memref<8xi32>
+  %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+  %8 = arith.addi %0, %4 : i32
+  %9 = arith.addi %1, %5 : i32
+  %10 = arith.addi %2, %6 : i32
+  %11 = arith.addi %3, %7 : i32
+
+  return %8, %9, %10, %11 : i32, i32, i32, i32
+}
+
 
 // CHECK-LABEL: func @read_read_add_write
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)

>From 7b24debf6d7e8e095dc60fc5b3bbcbbce19e27dd Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 17:41:39 +0200
Subject: [PATCH 17/28] fixes

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 56 +++++++++++++++++--
 mlir/test/Dialect/Vector/slp-vectorize.mlir   | 29 ++++++++++
 2 files changed, 79 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index a40131a1b10ff..ab0b3f549192f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -259,9 +259,13 @@ class SLPGraph {
       }
     });
 
+    auto isGoodNode = [&](SLPGraphNode *node) {
+      return node->users.empty() && node->operands.empty();
+    };
+
     IRMapping mapping;
     for (auto *node : sortedNodes) {
-      if (node->users.empty() && node->operands.empty())
+      if (isGoodNode(node))
         continue;
 
       int64_t numElements = node->ops.size();
@@ -326,6 +330,9 @@ class SLPGraph {
     }
 
     for (auto *node : llvm::reverse(sortedNodes)) {
+      if (isGoodNode(node))
+        continue;
+
       for (Operation *op : node->ops) {
         rewriter.eraseOp(op);
       }
@@ -560,10 +567,47 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
 
     auto *newNode = graph.addNode(currentOps);
     graph.addEdge(node, newNode);
-    for (Operation *op : currentOps) {
+    for (Operation *op : currentOps)
       fingerprints.invalidate(op);
+
+    worklist.push_back(newNode);
+  };
+
+  auto processOperands = [&](SLPGraphNode *node, Value operand, int64_t index) {
+    Operation *srcOp = operand.getDefiningOp();
+    if (!srcOp)
+      return;
+
+    auto *existingNode = graph.getNodeForOp(srcOp);
+    if (existingNode) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "  Adding edge from " << srcOp->getName() << " to "
+                 << node->ops.front()->getName() << "\n");
+      graph.addEdge(existingNode, node);
+      return;
+    }
+
+    if (!isVectorizable(srcOp))
+      return;
+    // An op may join only one node and one lane, else it is erased twice.
+    SmallVector<Operation *> currentOps = {srcOp};
+    for (Operation *op : ArrayRef(node->ops).drop_front()) {
+      Operation *otherOp = op->getOperand(index).getDefiningOp();
+      if (!otherOp || graph.getNodeForOp(otherOp) ||
+          llvm::is_contained(currentOps, otherOp) ||
+          !isEquivalent(otherOp, srcOp))
+        break;
+      currentOps.push_back(otherOp);
     }
 
+    if (currentOps.size() == 1)
+      return;
+
+    auto *newNode = graph.addNode(currentOps);
+    graph.addEdge(newNode, node);
+    for (Operation *op : currentOps)
+      fingerprints.invalidate(op);
+
     worklist.push_back(newNode);
   };
 
@@ -574,11 +618,11 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
                             << node->ops.front()->getName() << "\n");
 
     Operation *op = node->ops.front();
-    for (OpOperand &use : op->getUses()) {
+    for (OpOperand &use : op->getUses())
       processUse(node, use);
-      LLVM_DEBUG(llvm::dbgs() << "  Processing use in operation: "
-                              << use.getOwner()->getName() << "\n");
-    }
+
+    for (auto [i, operand] : llvm::enumerate(op->getOperands()))
+      processOperands(node, operand, i);
   }
 
   return graph;
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 2b2b91d667e00..036e1fcbed5d5 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -60,6 +60,35 @@ func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i3
   return %8, %9, %10, %11 : i32, i32, i32, i32
 }
 
+// CHECK-LABEL: func @add_write
+//  CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: i32, %[[ARG7:.*]]: i32, %[[ARG8:.*]]: memref<8xi32>)
+func.func @add_write(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32,
+                     %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32,
+                     %arg8: memref<8xi32>) {
+  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[A:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]] : vector<4xi32>
+  // CHECK:     %[[B:.*]] = vector.from_elements %[[ARG4]], %[[ARG5]], %[[ARG6]], %[[ARG7]] : vector<4xi32>
+  // CHECK:     %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+  // CHECK:     vector.store %[[RES]], %[[ARG8]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %8 = arith.addi %arg0, %arg4 : i32
+  %9 = arith.addi %arg1, %arg5 : i32
+  %10 = arith.addi %arg2, %arg6 : i32
+  %11 = arith.addi %arg3, %arg7 : i32
+
+  memref.store %8, %arg8[%c0] : memref<8xi32>
+  memref.store %9, %arg8[%c1] : memref<8xi32>
+  memref.store %10, %arg8[%c2] : memref<8xi32>
+  memref.store %11, %arg8[%c3] : memref<8xi32>
+
+  return
+}
+
+
 
 // CHECK-LABEL: func @read_read_add_write
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)

>From 5433fd954da2764b8ed92de5e42cac0d903f125a Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 18:07:37 +0200
Subject: [PATCH 18/28] handle size mismatch

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 21 +++++++
 mlir/test/Dialect/Vector/slp-vectorize.mlir   | 62 ++++++++++++++++++-
 2 files changed, 82 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index ab0b3f549192f..f54a9aba0e6c0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -302,6 +302,16 @@ class SLPGraph {
         }
       };
 
+      auto handleVecSizeMismatch = [&](Value arg) -> Value {
+        auto srcType = cast<VectorType>(arg.getType());
+        assert(srcType.getRank() == 1);
+        if (srcType.getDimSize(0) == numElements)
+          return arg;
+
+        return rewriter.create<vector::ExtractStridedSliceOp>(loc, arg, 0,
+                                                              numElements, 1);
+      };
+
       if (auto load = dyn_cast<memref::LoadOp>(op)) {
         auto vecType =
             VectorType::get(numElements, load.getMemRefType().getElementType());
@@ -312,6 +322,7 @@ class SLPGraph {
       } else if (auto store = dyn_cast<memref::StoreOp>(op)) {
         handleNonVectorInputs(store.getValueToStore());
         Value val = mapping.lookupOrDefault(store.getValueToStore());
+        val = handleVecSizeMismatch(val);
         rewriter.create<vector::StoreOp>(loc, val, store.getMemRef(),
                                          store.getIndices());
       } else if (isVectorizable(op)) {
@@ -319,6 +330,15 @@ class SLPGraph {
         Operation *newOp = rewriter.clone(*op, mapping);
         auto resVectorType =
             VectorType::get(numElements, op->getResultTypes().front());
+
+        {
+          OpBuilder::InsertionGuard guard(rewriter);
+          rewriter.setInsertionPoint(newOp);
+          for (OpOperand &operand : newOp->getOpOperands()) {
+            Value newOperand = handleVecSizeMismatch(operand.get());
+            operand.set(newOperand);
+          }
+        }
         newOp->getResult(0).setType(resVectorType);
 
         mapping.map(op->getResults(), newOp->getResults());
@@ -701,6 +721,7 @@ void SLPVectorizerPass::runOnOperation() {
       LLVM_DEBUG(llvm::dbgs() << "Failed to vectorize graph\n");
       return signalPassFailure();
     }
+    op->dump();
   });
 }
 
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 036e1fcbed5d5..76592833a78b4 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -25,6 +25,31 @@ func.func @read_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
   return
 }
 
+
+// CHECK-LABEL: func @read_write_size_mismatch
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_write_size_mismatch(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[RES:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[RES1:.*]] = vector.extract_strided_slice %[[RES]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+  // CHECK:     vector.store %[[RES1]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+  memref.store %0, %arg0[%c0] : memref<8xi32>
+  memref.store %1, %arg0[%c1] : memref<8xi32>
+
+  return
+}
+
+
 // CHECK-LABEL: func @read_read_add
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
 func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i32, i32, i32){
@@ -60,6 +85,7 @@ func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i3
   return %8, %9, %10, %11 : i32, i32, i32, i32
 }
 
+
 // CHECK-LABEL: func @add_write
 //  CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: i32, %[[ARG7:.*]]: i32, %[[ARG8:.*]]: memref<8xi32>)
 func.func @add_write(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32,
@@ -89,7 +115,6 @@ func.func @add_write(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32,
 }
 
 
-
 // CHECK-LABEL: func @read_read_add_write
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
 func.func @read_read_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
@@ -125,3 +150,38 @@ func.func @read_read_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
 
   return
 }
+
+
+// CHECK-LABEL: func @read_read_add_write_size_mismatch
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_size_mismatch(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[A1:.*]] = vector.extract_strided_slice %[[A]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+  // CHECK:     %[[B1:.*]] = vector.extract_strided_slice %[[B]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+  // CHECK:     %[[RES:.*]] = arith.addi %[[A1]], %[[B1]] : vector<2xi32>
+  // CHECK:     vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+  %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %5 = memref.load %arg1[%c1] : memref<8xi32>
+  %6 = memref.load %arg1[%c2] : memref<8xi32>
+  %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+  %8 = arith.addi %0, %4 : i32
+  %9 = arith.addi %1, %5 : i32
+
+  memref.store %8, %arg0[%c0] : memref<8xi32>
+  memref.store %9, %arg0[%c1] : memref<8xi32>
+
+  return
+}

>From 2f02d807ac75c95770d1ff72a082c69024616f2c Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 18:53:54 +0200
Subject: [PATCH 19/28] adjacent indices

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 98 ++++++++++---------
 mlir/test/Dialect/Vector/slp-vectorize.mlir   | 25 +++++
 2 files changed, 79 insertions(+), 44 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index f54a9aba0e6c0..cc252a0e32c06 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -52,16 +52,43 @@ struct MemoryOpGroup {
   bool empty() const { return ops.empty(); }
 };
 
-// Helper function to extract base and index from a memory operation
-std::optional<std::pair<Value, int64_t>> getBaseAndIndex(Operation *op) {
-  if (auto loadOp = dyn_cast<memref::LoadOp>(op)) {
-    if (auto value = getConstantIntValue(loadOp.getIndices().front()))
-      return std::make_pair(loadOp.getMemRef(), *value);
-  } else if (auto storeOp = dyn_cast<memref::StoreOp>(op)) {
-    if (auto value = getConstantIntValue(storeOp.getIndices().front()))
-      return std::make_pair(storeOp.getMemRef(), *value);
+static ValueRange getIndices(Operation *op) {
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return loadOp.getIndices();
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return storeOp.getIndices();
+  return {};
+}
+
+static Value getBase(Operation *op) {
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return loadOp.getMemRef();
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return storeOp.getMemRef();
+  return {};
+}
+
+static bool isAdjacentIndices(Value idx1, Value idx2) {
+  if (auto c1 = getConstantIntValue(idx1)) {
+    if (auto c2 = getConstantIntValue(idx2))
+      return *c1 + 1 == *c2;
   }
-  return std::nullopt;
+  return false;
+}
+
+static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2) {
+  if (idx1.empty() || idx1.size() != idx2.size())
+    return false;
+  // Every index except the innermost one must match exactly.
+  if (idx1.drop_back() != idx2.drop_back())
+    return false;
+  // The innermost index of idx2 must be exactly one past that of idx1.
+  return isAdjacentIndices(idx1.back(), idx2.back());
+}
+
+static bool isAdjacentIndices(Operation *op1, Operation *op2) {
+  return getBase(op1) == getBase(op2) && // same memref => same element type
+         isAdjacentIndices(getIndices(op1), getIndices(op2));
 }
 
 // Extract contiguous groups from a MemoryOpGroup
@@ -71,64 +98,48 @@ extractContiguousGroups(const MemoryOpGroup &group) {
   if (group.ops.empty())
     return result;
 
-  // Keep track of which operations we've processed
-  DenseSet<Operation *> processedOps;
+  llvm::SmallDenseSet<Operation *> processedOps;
 
-  // Process each operation
   for (Operation *op : group.ops) {
-    // Skip if we've already processed this operation
     if (processedOps.contains(op))
       continue;
 
-    // Get base and index of current operation
-    auto baseAndIndex = getBaseAndIndex(op);
-    if (!baseAndIndex)
-      continue;
-
-    auto [base, index] = *baseAndIndex;
-
     // Start a new group with this operation
     result.emplace_back(group.type);
     MemoryOpGroup &currentGroup = result.back();
-    currentGroup.ops.push_back(op);
+    auto &currentOps = currentGroup.ops;
+    currentOps.push_back(op);
     processedOps.insert(op);
 
-    LLVM_DEBUG(llvm::dbgs() << "Starting new group at base " << base
-                            << " index " << index << "\n");
-
-    // Try to find operations with adjacent indices
     bool foundMore;
     do {
       foundMore = false;
-      // Look for operations with index+1
       for (Operation *otherOp : group.ops) {
         if (processedOps.contains(otherOp))
           continue;
 
-        auto otherBaseAndIndex = getBaseAndIndex(otherOp);
-        if (!otherBaseAndIndex)
-          continue;
-
-        auto [otherBase, otherIndex] = *otherBaseAndIndex;
-
-        // Check if this operation has the same base and adjacent index
-        if (otherBase == base && otherIndex == currentGroup.ops.size()) {
-          currentGroup.ops.push_back(otherOp);
+        Operation *firstOp = currentOps.front();
+        Operation *lastOp = currentOps.back();
+        if (isAdjacentIndices(otherOp, firstOp)) {
+          currentOps.insert(currentOps.begin(), otherOp);
+          processedOps.insert(otherOp);
+          foundMore = true;
+        } else if (isAdjacentIndices(lastOp, otherOp)) {
+          currentOps.push_back(otherOp);
           processedOps.insert(otherOp);
           foundMore = true;
-          LLVM_DEBUG(llvm::dbgs()
-                     << "Added operation with index " << otherIndex << "\n");
-          break;
         }
       }
     } while (foundMore);
-  }
 
-  // Remove empty groups
-  result.erase(std::remove_if(result.begin(), result.end(),
-                              [](const MemoryOpGroup &g) { return g.empty(); }),
-               result.end());
+    if (currentOps.size() <= 1) {
+      result.pop_back();
+      continue;
+    }
 
+    LLVM_DEBUG(llvm::dbgs() << "Extracted contiguous group with "
+                            << currentGroup.ops.size() << " operations\n");
+  }
   return result;
 }
 
@@ -721,7 +732,6 @@ void SLPVectorizerPass::runOnOperation() {
       LLVM_DEBUG(llvm::dbgs() << "Failed to vectorize graph\n");
       return signalPassFailure();
     }
-    op->dump();
   });
 }
 
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 76592833a78b4..6be405ad078b9 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -50,6 +50,31 @@ func.func @read_write_size_mistamtch(%arg0: memref<8xi32>, %arg1: memref<8xi32>)
 }
 
 
+// CHECK-LABEL: func @read_write_interleaved
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_write_interleaved(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[RES:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+
+  memref.store %1, %arg0[%c1] : memref<8xi32>
+  memref.store %0, %arg0[%c0] : memref<8xi32>
+  memref.store %3, %arg0[%c3] : memref<8xi32>
+  memref.store %2, %arg0[%c2] : memref<8xi32>
+
+  return
+}
+
+
 // CHECK-LABEL: func @read_read_add
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
 func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i32, i32, i32){

>From 019f5614f606a9f5d031367de77de692819b0efc Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 18:57:52 +0200
Subject: [PATCH 20/28] test

---
 mlir/test/Dialect/Vector/slp-vectorize.mlir | 38 +++++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 6be405ad078b9..9c5005f807c71 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -210,3 +210,41 @@ func.func @read_read_add_write_size_mismatch(%arg0: memref<8xi32>, %arg1: memref
 
   return
 }
+
+
+// CHECK-LABEL: func @read_read_add_write_interleaved
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_interleaved(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+  // CHECK:     vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+  %7 = memref.load %arg1[%c3] : memref<8xi32>
+  %11 = arith.addi %3, %7 : i32
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %8 = arith.addi %0, %4 : i32
+
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %6 = memref.load %arg1[%c2] : memref<8xi32>
+  %10 = arith.addi %2, %6 : i32
+
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %5 = memref.load %arg1[%c1] : memref<8xi32>
+  %9 = arith.addi %1, %5 : i32
+
+  memref.store %8, %arg0[%c0] : memref<8xi32>
+  memref.store %11, %arg0[%c3] : memref<8xi32>
+  memref.store %10, %arg0[%c2] : memref<8xi32>
+  memref.store %9, %arg0[%c1] : memref<8xi32>
+
+  return
+}

>From fc5d42c732bc9632123925e8d152d0fff23e8813 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 19:07:26 +0200
Subject: [PATCH 21/28] fixes and test

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 11 +++-
 mlir/test/Dialect/Vector/slp-vectorize.mlir   | 54 +++++++++++++++++++
 2 files changed, 64 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index cc252a0e32c06..3ff46093d9fbe 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -52,6 +52,14 @@ struct MemoryOpGroup {
   bool empty() const { return ops.empty(); }
 };
 
+static Value getBase(Operation *op) {
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return loadOp.getMemRef();
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return storeOp.getMemRef();
+  return {};
+}
+
 static ValueRange getIndices(Operation *op) {
   if (auto loadOp = dyn_cast<memref::LoadOp>(op))
     return loadOp.getIndices();
@@ -87,7 +95,8 @@ static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2) {
 }
 
 static bool isAdjacentIndices(Operation *op1, Operation *op2) {
-  return getElementType(op1) == getElementType(op2) &&
+  return getBase(op1) == getBase(op2) &&
+         getElementType(op1) == getElementType(op2) &&
          isAdjacentIndices(getIndices(op1), getIndices(op2));
 }
 
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 9c5005f807c71..820fbf2d260cd 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -248,3 +248,57 @@ func.func @read_read_add_write_interleaved(%arg0: memref<8xi32>, %arg1: memref<8
 
   return
 }
+
+
+// CHECK-LABEL: func @read_read_add_add_write
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>
+//  CHECK-SAME: , %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
+func.func @read_read_add_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>,
+                                   %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) {
+  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[ADD1:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+  // CHECK:     %[[C:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]] : vector<4xi32>
+  // CHECK:     %[[ADD2:.*]] = arith.addi %[[A]], %[[C]] : vector<4xi32>
+  // CHECK:     vector.store %[[ADD1]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     vector.store %[[ADD2]], %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+  %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %5 = memref.load %arg1[%c1] : memref<8xi32>
+  %6 = memref.load %arg1[%c2] : memref<8xi32>
+  %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+  %8 = arith.addi %0, %4 : i32
+  %12 = arith.addi %0, %arg2 : i32
+
+  %13 = arith.addi %1, %arg3 : i32
+  %9 = arith.addi %1, %5 : i32
+
+  %10 = arith.addi %2, %6 : i32
+  %14 = arith.addi %2, %arg4 : i32
+
+  %15 = arith.addi %3, %arg5 : i32
+  %11 = arith.addi %3, %7 : i32
+
+  memref.store %8, %arg0[%c0] : memref<8xi32>
+  memref.store %9, %arg0[%c1] : memref<8xi32>
+  memref.store %10, %arg0[%c2] : memref<8xi32>
+  memref.store %11, %arg0[%c3] : memref<8xi32>
+
+  memref.store %12, %arg1[%c0] : memref<8xi32>
+  memref.store %13, %arg1[%c1] : memref<8xi32>
+  memref.store %14, %arg1[%c2] : memref<8xi32>
+  memref.store %15, %arg1[%c3] : memref<8xi32>
+
+  return
+}

>From e7e1172787b2697e7ad9860849b48a229d9e73bd Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 19:22:03 +0200
Subject: [PATCH 22/28] better side effects handling

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 94 +++++++++++--------
 1 file changed, 55 insertions(+), 39 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 3ff46093d9fbe..6cb6faa486702 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -52,6 +52,61 @@ struct MemoryOpGroup {
   bool empty() const { return ops.empty(); }
 };
 
+static bool isReadOp(Operation *op) {
+  auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
+  if (!effectInterface)
+    return true;
+
+  return effectInterface.hasEffect<MemoryEffects::Read>();
+}
+
+static bool isWriteOp(Operation *op) {
+  auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
+  if (!effectInterface)
+    return true;
+
+  return effectInterface.hasEffect<MemoryEffects::Write>();
+}
+
+/// Collect all memory operations in the block into groups.
+/// Each group contains either only loads or only stores, uninterrupted by any
+/// op that may write (load groups) or may read (store groups) memory.
+static SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block) {
+  SmallVector<MemoryOpGroup> groups;
+  MemoryOpGroup *currentGroup = nullptr;
+
+  for (Operation &op : block) {
+    if (currentGroup) {
+      if (currentGroup->isLoadGroup() && isWriteOp(&op)) {
+        currentGroup = nullptr;
+      } else if (currentGroup->isStoreGroup() && isReadOp(&op)) {
+        currentGroup = nullptr;
+      }
+    }
+
+    if (!isa<memref::LoadOp, memref::StoreOp>(op))
+      continue;
+
+    bool isLoad = isReadOp(&op);
+    MemoryOpGroup::Type type =
+        isLoad ? MemoryOpGroup::Type::Load : MemoryOpGroup::Type::Store;
+
+    if (!currentGroup) {
+      groups.emplace_back(type);
+      currentGroup = &groups.back();
+    }
+
+    currentGroup->ops.push_back(&op);
+  }
+
+  // Remove empty groups
+  groups.erase(std::remove_if(groups.begin(), groups.end(),
+                              [](const MemoryOpGroup &g) { return g.empty(); }),
+               groups.end());
+
+  return groups;
+}
+
 static Value getBase(Operation *op) {
   if (auto loadOp = dyn_cast<memref::LoadOp>(op))
     return loadOp.getMemRef();
@@ -449,12 +504,6 @@ class SLPGraph {
 struct SLPVectorizerPass
     : public mlir::vector::impl::SLPVectorizerBase<SLPVectorizerPass> {
   void runOnOperation() override;
-
-private:
-  /// Collect all memory operations in the block into groups.
-  /// Each group contains either all loads or all stores, uninterrupted by
-  /// operations of the other type.
-  SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block);
 };
 
 using Fingerprint = std::array<uint8_t, 20>;
@@ -668,39 +717,6 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
   return graph;
 }
 
-SmallVector<MemoryOpGroup>
-SLPVectorizerPass::collectMemoryOpGroups(Block &block) {
-  SmallVector<MemoryOpGroup> groups;
-  MemoryOpGroup *currentGroup = nullptr;
-
-  for (Operation &op : block) {
-    // Skip non-memory operations
-    if (!isa<memref::LoadOp, memref::StoreOp>(op))
-      continue;
-
-    bool isLoad = isa<memref::LoadOp>(op);
-    MemoryOpGroup::Type type =
-        isLoad ? MemoryOpGroup::Type::Load : MemoryOpGroup::Type::Store;
-
-    // Start a new group if:
-    // 1. We don't have a current group, or
-    // 2. The current operation is a different type than the current group
-    if (!currentGroup || currentGroup->type != type) {
-      groups.emplace_back(type);
-      currentGroup = &groups.back();
-    }
-
-    currentGroup->ops.push_back(&op);
-  }
-
-  // Remove empty groups
-  groups.erase(std::remove_if(groups.begin(), groups.end(),
-                              [](const MemoryOpGroup &g) { return g.empty(); }),
-               groups.end());
-
-  return groups;
-}
-
 void SLPVectorizerPass::runOnOperation() {
   Operation *op = getOperation();
 

>From ae187a0dada44517f56efe5eaa4ce981f8899dfc Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 19:44:47 +0200
Subject: [PATCH 23/28] cleanup

---
 .../mlir/Dialect/Vector/Transforms/Passes.h   |  3 --
 .../mlir/Dialect/Vector/Transforms/Passes.td  | 14 ++++--
 .../Vector/Transforms/SLPVectorizer.cpp       | 49 ++++++++++++-------
 mlir/test/Dialect/Vector/slp-vectorize.mlir   |  2 +-
 4 files changed, 43 insertions(+), 25 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index 43112f084dc60..5667f4fa95ace 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -25,9 +25,6 @@ std::unique_ptr<Pass> createLowerVectorMultiReductionPass(
     VectorMultiReductionLowering option =
         VectorMultiReductionLowering::InnerParallel);
 
-/// Creates a pass that implements the SLP vectorizer.
-std::unique_ptr<Pass> createSLPVectorizerPass();
-
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 94ccd61cb5170..d5c31c9f78409 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -34,15 +34,21 @@ def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::Func
   ];
 }
 
-def SLPVectorizer : Pass<"slp-vectorizer", "ModuleOp"> {
+def GreedySLPVectorizer : Pass<"greedy-slp-vectorizer"> {
   let summary = "SLP Vectorizer Pass";
   let description = [{
     This pass implements the SLP (Superword Level Parallelism) vectorizer.
     It detects consecutive operations that can be put together into vector
-    operations. The pass works bottom-up, across basic blocks, in search of
-    scalars to combine.
+    operations. The pass works bidirectionally, starting from loads or stores,
+    in search of scalars to combine.
+
+    This is a greedy vectorizer: it has no cost model (yet) and creates a
+    vector op whenever at least 2 scalar ops can be combined.
+
+    It also doesn't check whether the target supports the resulting vectors;
+    users will need a follow-up pass that splits large and/or unaligned vectors
+    into sizes actually supported by the target.
   }];
-  let constructor = "mlir::vector::createSLPVectorizerPass()";
   let dependentDialects = ["mlir::vector::VectorDialect"];
 }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 6cb6faa486702..d7c2dc3845cac 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -27,7 +27,7 @@
 
 namespace mlir {
 namespace vector {
-#define GEN_PASS_DEF_SLPVECTORIZER
+#define GEN_PASS_DEF_GREEDYSLPVECTORIZER
 #include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
 } // namespace vector
 } // namespace mlir
@@ -115,6 +115,19 @@ static Value getBase(Operation *op) {
   return {};
 }
 
+static bool isContiguousLastDim(Value val) {
+  auto memrefType = dyn_cast<MemRefType>(val.getType());
+  if (!memrefType)
+    return false;
+
+  int64_t offset;
+  SmallVector<int64_t> strides;
+  if (failed(memrefType.getStridesAndOffset(strides, offset)))
+    return false;
+
+  return !strides.empty() && strides.back() == 1;
+}
+
 static ValueRange getIndices(Operation *op) {
   if (auto loadOp = dyn_cast<memref::LoadOp>(op))
     return loadOp.getIndices();
@@ -150,8 +163,15 @@ static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2) {
 }
 
 static bool isAdjacentIndices(Operation *op1, Operation *op2) {
-  return getBase(op1) == getBase(op2) &&
-         getElementType(op1) == getElementType(op2) &&
+  Value base1 = getBase(op1);
+  Value base2 = getBase(op2);
+  if (base1 != base2)
+    return false;
+
+  if (!isContiguousLastDim(base1))
+    return false;
+
+  return getElementType(op1) == getElementType(op2) &&
          isAdjacentIndices(getIndices(op1), getIndices(op2));
 }
 
@@ -498,11 +518,9 @@ class SLPGraph {
   llvm::SmallDenseMap<Operation *, SLPGraphNode *> opToNode;
 };
 
-/// This pass implements the SLP vectorizer. It detects consecutive operations
-/// that can be put together into vector operations. The pass works bottom-up,
-/// across basic blocks, in search of scalars to combine.
-struct SLPVectorizerPass
-    : public mlir::vector::impl::SLPVectorizerBase<SLPVectorizerPass> {
+struct GreedySLPVectorizerPass
+    : public mlir::vector::impl::GreedySLPVectorizerBase<
+          GreedySLPVectorizerPass> {
   void runOnOperation() override;
 };
 
@@ -717,11 +735,11 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
   return graph;
 }
 
-void SLPVectorizerPass::runOnOperation() {
+void GreedySLPVectorizerPass::runOnOperation() {
   Operation *op = getOperation();
 
   // Walk all blocks recursively
-  op->walk([&](Block *block) {
+  op->walk([&](Block *block) -> WalkResult {
     LLVM_DEBUG(llvm::dbgs() << "Processing block in operation: "
                             << block->getParentOp()->getName() << "\n");
 
@@ -747,21 +765,18 @@ void SLPVectorizerPass::runOnOperation() {
 
     // Build the SLP graph from root groups
     SLPGraph graph = buildSLPGraph(rootGroups);
-
-    // Print the graph structure
     LLVM_DEBUG(graph.print());
 
     // Vectorize the graph
     IRRewriter rewriter(&getContext());
     if (failed(graph.vectorize(rewriter))) {
       LLVM_DEBUG(llvm::dbgs() << "Failed to vectorize graph\n");
-      return signalPassFailure();
+      signalPassFailure();
+      return WalkResult::interrupt();
     }
+
+    return WalkResult::advance();
   });
 }
 
 } // namespace
-
-std::unique_ptr<Pass> mlir::vector::createSLPVectorizerPass() {
-  return std::make_unique<SLPVectorizerPass>();
-}
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 820fbf2d260cd..2e9298d11ed05 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --slp-vectorizer | FileCheck %s
+// RUN: mlir-opt %s --greedy-slp-vectorizer | FileCheck %s
 
 
 // CHECK-LABEL: func @read_write

>From de5e898a81c1363acef5f5b00b9d9c254fa2554b Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 20:10:02 +0200
Subject: [PATCH 24/28] cleanup

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 80 +++++++++++++------
 1 file changed, 57 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index d7c2dc3845cac..24059ec355b30 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -49,7 +49,6 @@ struct MemoryOpGroup {
   bool isStoreGroup() const { return type == Type::Store; }
 
   size_t size() const { return ops.size(); }
-  bool empty() const { return ops.empty(); }
 };
 
 static bool isReadOp(Operation *op) {
@@ -99,11 +98,6 @@ static SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block) {
     currentGroup->ops.push_back(&op);
   }
 
-  // Remove empty groups
-  groups.erase(std::remove_if(groups.begin(), groups.end(),
-                              [](const MemoryOpGroup &g) { return g.empty(); }),
-               groups.end());
-
   return groups;
 }
 
@@ -144,14 +138,19 @@ static Type getElementType(Operation *op) {
   return {};
 }
 
+/// Check if two indices are consecutive, i.e. idx2 is equal to idx1 + 1.
 static bool isAdjacentIndices(Value idx1, Value idx2) {
   if (auto c1 = getConstantIntValue(idx1)) {
     if (auto c2 = getConstantIntValue(idx2))
       return *c1 + 1 == *c2;
   }
+
+  // TODO: Check arith.add, affine.apply, etc
   return false;
 }
 
+/// Check if two ranges of indices are consecutive, i.e. the last (fastest)
+/// index of idx2 is one past that of idx1 and all other indices are the same.
 static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2) {
   if (idx1.empty() || idx1.size() != idx2.size())
     return false;
@@ -162,7 +161,10 @@ static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2) {
   return isAdjacentIndices(idx1.back(), idx2.back());
 }
 
-static bool isAdjacentIndices(Operation *op1, Operation *op2) {
+/// Check if two operations are adjacent and can be combined into a vector op.
+/// This is done by checking that the base memrefs are the same, the last
+/// dimension is contiguous, the element types match and indices are adjacent.
+static bool isAdjacentOps(Operation *op1, Operation *op2) {
   Value base1 = getBase(op1);
   Value base2 = getBase(op2);
   if (base1 != base2)
@@ -195,6 +197,8 @@ extractContiguousGroups(const MemoryOpGroup &group) {
     currentOps.push_back(op);
     processedOps.insert(op);
 
+    // Keep adding ops to the beginning or end of the current group until no
+    // more ops can be added.
     bool foundMore;
     do {
       foundMore = false;
@@ -204,11 +208,11 @@ extractContiguousGroups(const MemoryOpGroup &group) {
 
         Operation *firstOp = currentOps.front();
         Operation *lastOp = currentOps.back();
-        if (isAdjacentIndices(otherOp, firstOp)) {
+        if (isAdjacentOps(otherOp, firstOp)) {
           currentOps.insert(currentOps.begin(), otherOp);
           processedOps.insert(otherOp);
           foundMore = true;
-        } else if (isAdjacentIndices(lastOp, otherOp)) {
+        } else if (isAdjacentOps(lastOp, otherOp)) {
           currentOps.push_back(otherOp);
           processedOps.insert(otherOp);
           foundMore = true;
@@ -222,7 +226,7 @@ extractContiguousGroups(const MemoryOpGroup &group) {
     }
 
     LLVM_DEBUG(llvm::dbgs() << "Extracted contiguous group with "
-                            << currentGroup.ops.size() << " operations\n");
+                            << currentGroup.size() << " operations\n");
   }
   return result;
 }
@@ -241,6 +245,8 @@ struct SLPGraphNode {
   SLPGraphNode() = default;
   SLPGraphNode(ArrayRef<Operation *> operations)
       : ops(operations.begin(), operations.end()) {}
+
+  size_t size() const { return ops.size(); }
 };
 
 /// A graph of vectorizable operations
@@ -349,7 +355,7 @@ class SLPGraph {
     LLVM_DEBUG({
       llvm::dbgs() << "Topologically sorted nodes:\n";
       for (auto *node : sortedNodes) {
-        llvm::dbgs() << "  Node with " << node->ops.size()
+        llvm::dbgs() << "  Node with " << node->size()
                      << " operations: " << node->ops.front()->getName() << "\n";
       }
     });
@@ -363,7 +369,7 @@ class SLPGraph {
       if (isGoodNode(node))
         continue;
 
-      int64_t numElements = node->ops.size();
+      int64_t numElements = node->size();
       Operation *op = node->ops.front();
       rewriter.setInsertionPoint(op);
       Location loc = op->getLoc();
@@ -467,15 +473,15 @@ class SLPGraph {
       if (!node->isRoot)
         continue;
       llvm::dbgs() << "  "
-                   << (isa<memref::LoadOp>(node->ops[0]) ? "LOAD" : "STORE")
-                   << " group with " << node->ops.size() << " operations:\n";
+                   << (isa<memref::LoadOp>(node->ops.front()) ? "LOAD"
+                                                              : "STORE")
+                   << " group with " << node->size() << " operations:\n";
       for (auto *op : node->ops) {
         llvm::dbgs() << "    " << *op << "\n";
       }
       llvm::dbgs() << "    Users: ";
       for (auto *user : node->users) {
-        llvm::dbgs() << "\n      Group with " << user->ops.size()
-                     << " operations:";
+        llvm::dbgs() << "\n      Group with " << user->size() << " operations:";
         for (auto *op : user->ops) {
           llvm::dbgs() << "\n        " << *op;
         }
@@ -488,13 +494,13 @@ class SLPGraph {
     for (const auto &node : nodes) {
       if (node->isRoot)
         continue;
-      llvm::dbgs() << "  Group with " << node->ops.size() << " operations:\n";
+      llvm::dbgs() << "  Group with " << node->size() << " operations:\n";
       for (auto *op : node->ops) {
         llvm::dbgs() << "    " << *op << "\n";
       }
       llvm::dbgs() << "    Operands: ";
       for (auto *operand : node->operands) {
-        llvm::dbgs() << "\n      Group with " << operand->ops.size()
+        llvm::dbgs() << "\n      Group with " << operand->size()
                      << " operations:";
         for (auto *op : operand->ops) {
           llvm::dbgs() << "\n        " << *op;
@@ -502,8 +508,7 @@ class SLPGraph {
       }
       llvm::dbgs() << "\n    Users: ";
       for (auto *user : node->users) {
-        llvm::dbgs() << "\n      Group with " << user->ops.size()
-                     << " operations:";
+        llvm::dbgs() << "\n      Group with " << user->size() << " operations:";
         for (auto *op : user->ops) {
           llvm::dbgs() << "\n        " << *op;
         }
@@ -518,6 +523,28 @@ class SLPGraph {
   llvm::SmallDenseMap<Operation *, SLPGraphNode *> opToNode;
 };
 
+/// This pass implements the greedy SLP vectorizer. It detects consecutive
+/// operations that can be put together into vector operations. The pass works
+/// bidirectionally, starting from loads or stores, in search of scalars to
+/// combine.
+///
+/// The pass is split into multiple steps:
+/// 1. Collect memory operation groups within the same block.
+/// A group is either multiple loads not separated by any op that may write
+/// memory, or multiple stores not separated by any op that may read memory.
+///
+/// 2. Extract contiguous groups from memory operation groups, based on the
+/// ops base memrefs, load/store element types, and indices.
+///
+/// 3. Build SLP graph from contiguous groups. This is done by going both
+/// top-down and bottom-up through uses/operands respectively, starting from
+/// contiguous memory operation groups.
+///
+/// 4. Vectorize SLP graph. This is done by topological sort of the graph and
+/// vectorizing each node in the order of the sort.
+///
+/// Vectorization is done by cloning the operations and mapping the operands and
+/// results.
 struct GreedySLPVectorizerPass
     : public mlir::vector::impl::GreedySLPVectorizerBase<
           GreedySLPVectorizerPass> {
@@ -532,6 +559,10 @@ static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
       ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
 }
 
+/// The SLP vectorizer is bidirectional, so when going top-down we can have
+/// multiple users with the same immediate op type. This class computes a
+/// fingerprint for such ops based on the entire ops graph to maximize further
+/// merging of scalar ops.
 struct OperationsFingerprint {
   OperationsFingerprint(const SLPGraph &graph) : graph(graph) {}
 
@@ -606,7 +637,8 @@ static bool isEquivalent(Operation *op1, Operation *op2) {
   return true;
 }
 
-/// Build the SLP graph starting from memory operation groups
+/// Build the SLP graph starting from memory operation groups and going both
+/// top-down and bottom-up through uses/operands respectively.
 static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
   if (rootGroups.empty())
     return SLPGraph();
@@ -623,7 +655,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
     worklist.push_back(node);
 
     LLVM_DEBUG({
-      llvm::dbgs() << "Created root group node with " << node->ops.size()
+      llvm::dbgs() << "Created root group node with " << node->size()
                    << " operations of type "
                    << (group.isLoadGroup() ? "Load" : "Store") << "\n";
     });
@@ -631,6 +663,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
 
   OperationsFingerprint fingerprints(graph);
 
+  // Process node uses, going top-down.
   auto processUse = [&](SLPGraphNode *node, OpOperand &use) {
     Operation *user = use.getOwner();
     auto *existingNode = graph.getNodeForOp(user);
@@ -680,6 +713,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
     worklist.push_back(newNode);
   };
 
+  // Process node operands, going bottom-up.
   auto processOperands = [&](SLPGraphNode *node, Value operand, int64_t index) {
     Operation *srcOp = operand.getDefiningOp();
     if (!srcOp)
@@ -720,7 +754,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
 
   while (!worklist.empty()) {
     SLPGraphNode *node = worklist.pop_back_val();
-    LLVM_DEBUG(llvm::dbgs() << "Processing node with " << node->ops.size()
+    LLVM_DEBUG(llvm::dbgs() << "Processing node with " << node->size()
                             << " operations, first op: "
                             << node->ops.front()->getName() << "\n");
 

>From 910f7a094c061bd6f1152c194f71462a7220ec0f Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 22:08:35 +0200
Subject: [PATCH 25/28] check arith.add indices

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 13 ++++-
 mlir/test/Dialect/Vector/slp-vectorize.mlir   | 54 +++++++++++++++++++
 2 files changed, 66 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 24059ec355b30..aa2f3108712f1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -145,7 +145,18 @@ static bool isAdjacentIndices(Value idx1, Value idx2) {
       return *c1 + 1 == *c2;
   }
 
-  // TODO: Check arith.add, affine.apply, etc
+  if (auto addOp2 = idx2.getDefiningOp<arith::AddIOp>()) {
+    if (addOp2.getLhs() == idx1 && getConstantIntValue(addOp2.getRhs()) == 1)
+      return true;
+
+    if (auto addOp1 = idx1.getDefiningOp<arith::AddIOp>()) {
+      if (addOp1.getLhs() == addOp2.getLhs() &&
+          isAdjacentIndices(addOp1.getRhs(), addOp2.getRhs()))
+        return true;
+    }
+  }
+
+  // TODO: affine.apply, etc
   return false;
 }
 
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 2e9298d11ed05..edb722472995d 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -75,6 +75,60 @@ func.func @read_write_interleaved(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
 }
 
 
+// CHECK-LABEL: func @read_write_add_index
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>, %[[ARG2:.*]]: index)
+func.func @read_write_add_index(%arg0: memref<8xi32>, %arg1: memref<8xi32>, %arg2: index) {
+  // CHECK:     %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG2]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     vector.store %[[RES]], %[[ARG0]][%[[ARG2]]] : memref<8xi32>, vector<4xi32>
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %ind1 = arith.addi %arg2, %c1 : index
+  %ind2 = arith.addi %arg2, %c2 : index
+  %ind3 = arith.addi %arg2, %c3 : index
+
+  %0 = memref.load %arg0[%arg2] : memref<8xi32>
+  %1 = memref.load %arg0[%ind1] : memref<8xi32>
+  %2 = memref.load %arg0[%ind2] : memref<8xi32>
+  %3 = memref.load %arg0[%ind3] : memref<8xi32>
+
+  memref.store %0, %arg0[%arg2] : memref<8xi32>
+  memref.store %1, %arg0[%ind1] : memref<8xi32>
+  memref.store %2, %arg0[%ind2] : memref<8xi32>
+  memref.store %3, %arg0[%ind3] : memref<8xi32>
+
+  return
+}
+
+
+// CHECK-LABEL: func @read_write_add_index_interleaved
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>, %[[ARG2:.*]]: index)
+func.func @read_write_add_index_interleaved(%arg0: memref<8xi32>, %arg1: memref<8xi32>, %arg2: index) {
+  // CHECK:     %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG2]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     vector.store %[[RES]], %[[ARG0]][%[[ARG2]]] : memref<8xi32>, vector<4xi32>
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+
+  %ind1 = arith.addi %arg2, %c1 : index
+  %ind2 = arith.addi %arg2, %c2 : index
+  %ind3 = arith.addi %arg2, %c3 : index
+
+  %0 = memref.load %arg0[%arg2] : memref<8xi32>
+  %1 = memref.load %arg0[%ind1] : memref<8xi32>
+  %3 = memref.load %arg0[%ind3] : memref<8xi32>
+  %2 = memref.load %arg0[%ind2] : memref<8xi32>
+
+  memref.store %3, %arg0[%ind3] : memref<8xi32>
+  memref.store %0, %arg0[%arg2] : memref<8xi32>
+  memref.store %1, %arg0[%ind1] : memref<8xi32>
+  memref.store %2, %arg0[%ind2] : memref<8xi32>
+
+  return
+}
+
+
 // CHECK-LABEL: func @read_read_add
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
 func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i32, i32, i32){

>From 0db4c55a99cc7328a8e2b0233f545f9325eb393d Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 23:41:34 +0200
Subject: [PATCH 26/28] fix vector sizes

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 24 ++++++----
 mlir/test/Dialect/Vector/slp-vectorize.mlir   | 47 +++++++++++++++++++
 2 files changed, 61 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index aa2f3108712f1..dfd4747f615ee 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -12,14 +12,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/Passes.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/SHA1.h"
 
@@ -371,15 +368,24 @@ class SLPGraph {
       }
     });
 
-    auto isGoodNode = [&](SLPGraphNode *node) {
+    auto isBadNode = [&](SLPGraphNode *node) {
       return node->users.empty() && node->operands.empty();
     };
 
-    IRMapping mapping;
+    // Truncate each node to the size of its smallest operand node.
     for (auto *node : sortedNodes) {
-      if (isGoodNode(node))
-        continue;
+      size_t size = node->size();
+      for (auto *operand : node->operands)
+        size = std::min(size, operand->size());
+
+      node->ops.resize(size);
+    }
+
+    // Drop isolated nodes (no users and no operands); they stay scalar.
+    llvm::erase_if(sortedNodes, isBadNode);
 
+    IRMapping mapping;
+    for (auto *node : sortedNodes) {
       int64_t numElements = node->size();
       Operation *op = node->ops.front();
       rewriter.setInsertionPoint(op);
@@ -462,14 +468,12 @@ class SLPGraph {
     }
 
     for (auto *node : llvm::reverse(sortedNodes)) {
-      if (isGoodNode(node))
-        continue;
-
       for (Operation *op : node->ops) {
         rewriter.eraseOp(op);
       }
     }
 
+    LLVM_DEBUG(llvm::dbgs() << "Vectorization completed successfully\n");
     return success();
   }
 
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index edb722472995d..7ad077d8fd78c 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -356,3 +356,50 @@ func.func @read_read_add_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>,
 
   return
 }
+
+
+func.func private @use(i32)
+
+// CHECK-LABEL: func @read_read_add_write_interleaved_use
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_interleaved_use(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[C3:.*]] = arith.constant 3 : index
+  // CHECK:     %[[V0:.*]] = memref.load %[[ARG0]][%[[C3]]] : memref<8xi32>
+  // CHECK:     %[[V1:.*]] = memref.load %[[ARG1]][%[[C3]]] : memref<8xi32>
+  // CHECK:     call @use(%[[V0]]) : (i32) -> ()
+  // CHECK:     %[[V2:.*]] = arith.addi %[[V0]], %[[V1]] : i32
+  // CHECK:     %[[V3:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<3xi32>
+  // CHECK:     %[[V4:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<3xi32>
+  // CHECK:     %[[V5:.*]] = arith.addi %[[V3]], %[[V4]] : vector<3xi32>
+  // CHECK:     vector.store %[[V5]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<3xi32>
+  // CHECK:     memref.store %[[V2]], %[[ARG0]][%[[C3]]] : memref<8xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  // Lane 3 comes first and its load also feeds @use; expected to stay scalar.
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+  %7 = memref.load %arg1[%c3] : memref<8xi32>
+  call @use(%3) : (i32) -> ()
+  %11 = arith.addi %3, %7 : i32
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %8 = arith.addi %0, %4 : i32
+
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %6 = memref.load %arg1[%c2] : memref<8xi32>
+  %10 = arith.addi %2, %6 : i32
+
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %5 = memref.load %arg1[%c1] : memref<8xi32>
+  %9 = arith.addi %1, %5 : i32
+
+  memref.store %8, %arg0[%c0] : memref<8xi32>
+  memref.store %11, %arg0[%c3] : memref<8xi32>
+  memref.store %10, %arg0[%c2] : memref<8xi32>
+  memref.store %9, %arg0[%c1] : memref<8xi32>
+
+  return
+}

>From 08dcd13cb9ff7351f4eb690fc787eb71fd306369 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 23:48:08 +0200
Subject: [PATCH 27/28] fix op insertion point

---
 .../Vector/Transforms/SLPVectorizer.cpp       | 12 ++++-
 mlir/test/Dialect/Vector/slp-vectorize.mlir   | 44 +++++++++++++++++++
 2 files changed, 55 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index dfd4747f615ee..ab5bfd94de49c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -255,6 +255,16 @@ struct SLPGraphNode {
       : ops(operations.begin(), operations.end()) {}
 
   size_t size() const { return ops.size(); }
+  /// Returns the node's earliest op; all ops must be in the same block.
+  Operation *getEarliestOp() const {
+    assert(!ops.empty() && "empty node");
+    Operation *ret = ops.front();
+    for (Operation *op : ArrayRef(ops).drop_front()) {
+      if (op->isBeforeInBlock(ret))
+        ret = op;
+    }
+    return ret;
+  }
 };
 
 /// A graph of vectorizable operations
@@ -388,7 +398,7 @@ class SLPGraph {
     for (auto *node : sortedNodes) {
       int64_t numElements = node->size();
       Operation *op = node->ops.front();
-      rewriter.setInsertionPoint(op);
+      rewriter.setInsertionPoint(node->getEarliestOp());
       Location loc = op->getLoc();
 
       auto handleNonVectorInputs = [&](ValueRange operands) {
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 7ad077d8fd78c..9d06a1faa07b2 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -403,3 +403,47 @@ func.func @read_read_add_write_interleaved_use(%arg0: memref<8xi32>, %arg1: memr
 
   return
 }
+
+
+// CHECK-LABEL: func @read_read_add_write_interleaved_use_add
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_interleaved_use_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:     %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[V1:.*]] = vector.extract %[[V0]][3] : i32 from vector<4xi32>
+  // CHECK:     %[[V2:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     %[[V3:.*]] = vector.extract %[[V2]][3] : i32 from vector<4xi32>
+  // CHECK:     %[[V4:.*]] = arith.subi %[[V1]], %[[V3]] : i32
+  // CHECK:     %[[V5:.*]] = arith.addi %[[V0]], %[[V2]] : vector<4xi32>
+  // CHECK:     vector.store %[[V5]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+  // CHECK:     call @use(%[[V4]]) : (i32) -> ()
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  // Lane 3 also feeds a subi passed to @use; expected via vector.extract.
+  %3 = memref.load %arg0[%c3] : memref<8xi32>
+  %7 = memref.load %arg1[%c3] : memref<8xi32>
+  %12 = arith.subi %3, %7 : i32
+  %11 = arith.addi %3, %7 : i32
+
+  %0 = memref.load %arg0[%c0] : memref<8xi32>
+  %4 = memref.load %arg1[%c0] : memref<8xi32>
+  %8 = arith.addi %0, %4 : i32
+
+  %2 = memref.load %arg0[%c2] : memref<8xi32>
+  %6 = memref.load %arg1[%c2] : memref<8xi32>
+  %10 = arith.addi %2, %6 : i32
+
+  %1 = memref.load %arg0[%c1] : memref<8xi32>
+  %5 = memref.load %arg1[%c1] : memref<8xi32>
+  %9 = arith.addi %1, %5 : i32
+
+  memref.store %8, %arg0[%c0] : memref<8xi32>
+  memref.store %11, %arg0[%c3] : memref<8xi32>
+  memref.store %10, %arg0[%c2] : memref<8xi32>
+  memref.store %9, %arg0[%c1] : memref<8xi32>
+
+  call @use(%12) : (i32) -> ()
+  return
+}

>From 8526de52a6aaaebac8818138a1f947d203942b32 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 19 May 2025 11:00:11 +0200
Subject: [PATCH 28/28] check same block

---
 mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index ab5bfd94de49c..ec1c41dbd7b69 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -659,6 +659,9 @@ static bool isEquivalent(Operation *op1, Operation *op2) {
   if (op1->getRawDictionaryAttrs() != op2->getRawDictionaryAttrs())
     return false;
 
+  if (op1->getBlock() != op2->getBlock())
+    return false;
+
   return true;
 }
 



More information about the Mlir-commits mailing list