[Mlir-commits] [mlir] [mlir][vector] MLIR SLP vectorizer (PR #140469)
Ivan Butygin
llvmlistbot at llvm.org
Sun May 18 14:49:40 PDT 2025
https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/140469
>From 36c7f1e004c3a872dff9005c7113b6c99760671c Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 17 May 2025 23:01:41 +0200
Subject: [PATCH 01/27] stubs
---
.../mlir/Dialect/Vector/Transforms/Passes.h | 3 +
.../mlir/Dialect/Vector/Transforms/Passes.td | 12 ++++
.../Dialect/Vector/Transforms/CMakeLists.txt | 1 +
.../Vector/Transforms/SLPVectorizer.cpp | 63 +++++++++++++++++++
mlir/test/Dialect/Vector/slp-vectorize.mlir | 34 ++++++++++
5 files changed, 113 insertions(+)
create mode 100644 mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
create mode 100644 mlir/test/Dialect/Vector/slp-vectorize.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index 5667f4fa95ace..43112f084dc60 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -25,6 +25,9 @@ std::unique_ptr<Pass> createLowerVectorMultiReductionPass(
VectorMultiReductionLowering option =
VectorMultiReductionLowering::InnerParallel);
+/// Creates a pass that implements the SLP vectorizer.
+std::unique_ptr<Pass> createSLPVectorizerPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 7436998749791..94ccd61cb5170 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -34,4 +34,16 @@ def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::Func
];
}
+def SLPVectorizer : Pass<"slp-vectorizer", "ModuleOp"> {
+ let summary = "SLP Vectorizer Pass";
+ let description = [{
+ This pass implements the SLP (Superword Level Parallelism) vectorizer.
+ It detects consecutive operations that can be put together into vector
+ operations. The pass works bottom-up, across basic blocks, in search of
+ scalars to combine.
+ }];
+ let constructor = "mlir::vector::createSLPVectorizerPass()";
+ let dependentDialects = ["mlir::vector::VectorDialect"];
+}
+
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8ca5cb6c6dfab..37333b739bd86 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorStep.cpp
LowerVectorTransfer.cpp
LowerVectorTranspose.cpp
+ SLPVectorizer.cpp
SubsetOpInterfaceImpl.cpp
VectorDistribute.cpp
VectorDropLeadUnitDim.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
new file mode 100644
index 0000000000000..e9f3b12bc7461
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -0,0 +1,63 @@
+//===- SLPVectorizer.cpp - SLP Vectorizer Pass ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the SLP vectorizer pass for MLIR. The pass attempts to
+// combine similar independent operations into vector operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "slp-vectorizer"
+
+namespace mlir {
+namespace vector {
+#define GEN_PASS_DEF_SLPVECTORIZER
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+} // namespace vector
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+/// This pass implements the SLP vectorizer. It detects consecutive operations
+/// that can be put together into vector operations. The pass works bottom-up,
+/// across basic blocks, in search of scalars to combine.
+struct SLPVectorizerPass
+ : public mlir::vector::impl::SLPVectorizerBase<SLPVectorizerPass> {
+ void runOnOperation() override;
+};
+
+} // namespace
+
+void SLPVectorizerPass::runOnOperation() {
+ Operation *op = getOperation();
+ MLIRContext *context = &getContext();
+
+ // TODO: Implement SLP vectorization logic
+ // 1. Find candidate operations for vectorization
+ // 2. Build vectorization trees
+ // 3. Perform vectorization if profitable
+ // 4. Clean up scalar operations
+
+ LLVM_DEBUG(llvm::dbgs() << "Running SLP Vectorizer pass\n");
+ llvm::errs() << "Running SLP Vectorizer pass\n";
+}
+
+std::unique_ptr<Pass> mlir::vector::createSLPVectorizerPass() {
+ return std::make_unique<SLPVectorizerPass>();
+}
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
new file mode 100644
index 0000000000000..31543f3a76b2e
--- /dev/null
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s -test-slp-vectorization | FileCheck %s
+
+// CHECK-LABEL: func @basic_slp
+func.func @basic_slp(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+ // CHECK: vector.transfer_read
+ // CHECK: arith.addi
+ // CHECK: vector.transfer_write
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %0 = memref.load %arg0[%c0] : memref<8xi32>
+ %1 = memref.load %arg0[%c1] : memref<8xi32>
+ %2 = memref.load %arg0[%c2] : memref<8xi32>
+ %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+ %4 = memref.load %arg1[%c0] : memref<8xi32>
+ %5 = memref.load %arg1[%c1] : memref<8xi32>
+ %6 = memref.load %arg1[%c2] : memref<8xi32>
+ %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+ %8 = arith.addi %0, %4 : i32
+ %9 = arith.addi %1, %5 : i32
+ %10 = arith.addi %2, %6 : i32
+ %11 = arith.addi %3, %7 : i32
+
+ memref.store %8, %arg0[%c0] : memref<8xi32>
+ memref.store %9, %arg0[%c1] : memref<8xi32>
+ memref.store %10, %arg0[%c2] : memref<8xi32>
+ memref.store %11, %arg0[%c3] : memref<8xi32>
+
+ return
+}
>From 2b4e64b0e7790b4b265c5924e1c0764ca1083a40 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 17 May 2025 23:20:34 +0200
Subject: [PATCH 02/27] something working
---
.../Vector/Transforms/SLPVectorizer.cpp | 113 +++++++++++++++++-
mlir/test/Dialect/Vector/slp-vectorize.mlir | 2 +-
2 files changed, 108 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index e9f3b12bc7461..b696f36c82eee 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
@@ -34,27 +35,127 @@ using namespace mlir;
using namespace mlir::vector;
namespace {
+/// A group of consecutive memory operations of the same type (load or store)
+/// that can potentially be vectorized together.
+struct MemoryOpGroup {
+ enum class Type { Load, Store };
+ Type type;
+ SmallVector<Operation *> ops;
+
+ MemoryOpGroup(Type t) : type(t) {}
+
+ bool isLoadGroup() const { return type == Type::Load; }
+ bool isStoreGroup() const { return type == Type::Store; }
+
+ size_t size() const { return ops.size(); }
+ bool empty() const { return ops.empty(); }
+};
+
/// This pass implements the SLP vectorizer. It detects consecutive operations
/// that can be put together into vector operations. The pass works bottom-up,
/// across basic blocks, in search of scalars to combine.
struct SLPVectorizerPass
: public mlir::vector::impl::SLPVectorizerBase<SLPVectorizerPass> {
void runOnOperation() override;
+
+private:
+ /// Collect all memory operations in the block into groups.
+ /// Each group contains either all loads or all stores, uninterrupted by
+ /// operations of the other type.
+ SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block);
};
} // namespace
+SmallVector<MemoryOpGroup>
+SLPVectorizerPass::collectMemoryOpGroups(Block &block) {
+ SmallVector<MemoryOpGroup> groups;
+ MemoryOpGroup *currentGroup = nullptr;
+
+ LLVM_DEBUG(llvm::dbgs() << "Scanning block for memory operations...\n");
+
+ for (Operation &op : block) {
+ LLVM_DEBUG(llvm::dbgs() << "Checking operation: " << op.getName() << "\n");
+
+ // Skip non-memory operations
+ if (!isa<memref::LoadOp, memref::StoreOp>(op)) {
+ LLVM_DEBUG(llvm::dbgs() << " Not a memory operation\n");
+ continue;
+ }
+
+ bool isLoad = isa<memref::LoadOp>(op);
+ MemoryOpGroup::Type type =
+ isLoad ? MemoryOpGroup::Type::Load : MemoryOpGroup::Type::Store;
+
+ LLVM_DEBUG(llvm::dbgs()
+ << " Found " << (isLoad ? "load" : "store") << " operation\n");
+
+ // Start a new group if:
+ // 1. We don't have a current group, or
+ // 2. The current operation is a different type than the current group
+ if (!currentGroup || currentGroup->type != type) {
+ LLVM_DEBUG(llvm::dbgs() << " Starting new group\n");
+ groups.emplace_back(type);
+ currentGroup = &groups.back();
+ }
+
+ currentGroup->ops.push_back(&op);
+ }
+
+ // Remove empty groups
+ groups.erase(std::remove_if(groups.begin(), groups.end(),
+ [](const MemoryOpGroup &g) { return g.empty(); }),
+ groups.end());
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "Found " << groups.size() << " memory operation groups:\n";
+ for (const auto &group : groups) {
+ llvm::dbgs() << " Group type: "
+ << (group.isLoadGroup() ? "Load" : "Store")
+ << ", size: " << group.size() << "\n";
+ }
+ });
+
+ return groups;
+}
+
void SLPVectorizerPass::runOnOperation() {
Operation *op = getOperation();
MLIRContext *context = &getContext();
- // TODO: Implement SLP vectorization logic
- // 1. Find candidate operations for vectorization
- // 2. Build vectorization trees
- // 3. Perform vectorization if profitable
- // 4. Clean up scalar operations
+ LLVM_DEBUG(llvm::dbgs() << "Running SLP Vectorizer pass on operation: "
+ << op->getName() << "\n");
+
+ // Process each function in the module
+ for (Region &region : op->getRegions()) {
+ for (Block &block : region) {
+ for (Operation &op : block) {
+ // If this is a function, process its body
+ if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Processing function: " << funcOp.getName() << "\n");
+
+ // Process each block in the function
+ for (Block &funcBlock : funcOp.getBody()) {
+ // Collect memory operation groups
+ SmallVector<MemoryOpGroup> groups =
+ collectMemoryOpGroups(funcBlock);
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "Found " << groups.size()
+ << " memory operation groups:\n";
+ for (const auto &group : groups) {
+ llvm::dbgs() << " Group type: "
+ << (group.isLoadGroup() ? "Load" : "Store")
+ << ", size: " << group.size() << "\n";
+ }
+ });
+ }
+ }
+ }
+ }
+ }
- LLVM_DEBUG(llvm::dbgs() << "Running SLP Vectorizer pass\n");
llvm::errs() << "Running SLP Vectorizer pass\n";
}
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 31543f3a76b2e..a07dd05dd16aa 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-slp-vectorization | FileCheck %s
+// RUN: mlir-opt %s --slp-vectorizer | FileCheck %s
// CHECK-LABEL: func @basic_slp
func.func @basic_slp(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
>From 14804ce4fa80df248e435998a40dfc953e463324 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 17 May 2025 23:29:20 +0200
Subject: [PATCH 03/27] block walk
---
.../Vector/Transforms/SLPVectorizer.cpp | 64 ++++---------------
1 file changed, 11 insertions(+), 53 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index b696f36c82eee..bec5f9d90b21b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -72,29 +72,19 @@ SLPVectorizerPass::collectMemoryOpGroups(Block &block) {
SmallVector<MemoryOpGroup> groups;
MemoryOpGroup *currentGroup = nullptr;
- LLVM_DEBUG(llvm::dbgs() << "Scanning block for memory operations...\n");
-
for (Operation &op : block) {
- LLVM_DEBUG(llvm::dbgs() << "Checking operation: " << op.getName() << "\n");
-
// Skip non-memory operations
- if (!isa<memref::LoadOp, memref::StoreOp>(op)) {
- LLVM_DEBUG(llvm::dbgs() << " Not a memory operation\n");
+ if (!isa<memref::LoadOp, memref::StoreOp>(op))
continue;
- }
bool isLoad = isa<memref::LoadOp>(op);
MemoryOpGroup::Type type =
isLoad ? MemoryOpGroup::Type::Load : MemoryOpGroup::Type::Store;
- LLVM_DEBUG(llvm::dbgs()
- << " Found " << (isLoad ? "load" : "store") << " operation\n");
-
// Start a new group if:
// 1. We don't have a current group, or
// 2. The current operation is a different type than the current group
if (!currentGroup || currentGroup->type != type) {
- LLVM_DEBUG(llvm::dbgs() << " Starting new group\n");
groups.emplace_back(type);
currentGroup = &groups.back();
}
@@ -107,15 +97,6 @@ SLPVectorizerPass::collectMemoryOpGroups(Block &block) {
[](const MemoryOpGroup &g) { return g.empty(); }),
groups.end());
- LLVM_DEBUG({
- llvm::dbgs() << "Found " << groups.size() << " memory operation groups:\n";
- for (const auto &group : groups) {
- llvm::dbgs() << " Group type: "
- << (group.isLoadGroup() ? "Load" : "Store")
- << ", size: " << group.size() << "\n";
- }
- });
-
return groups;
}
@@ -123,40 +104,17 @@ void SLPVectorizerPass::runOnOperation() {
Operation *op = getOperation();
MLIRContext *context = &getContext();
- LLVM_DEBUG(llvm::dbgs() << "Running SLP Vectorizer pass on operation: "
- << op->getName() << "\n");
-
- // Process each function in the module
- for (Region &region : op->getRegions()) {
- for (Block &block : region) {
- for (Operation &op : block) {
- // If this is a function, process its body
- if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
- LLVM_DEBUG(llvm::dbgs()
- << "Processing function: " << funcOp.getName() << "\n");
-
- // Process each block in the function
- for (Block &funcBlock : funcOp.getBody()) {
- // Collect memory operation groups
- SmallVector<MemoryOpGroup> groups =
- collectMemoryOpGroups(funcBlock);
-
- LLVM_DEBUG({
- llvm::dbgs() << "Found " << groups.size()
- << " memory operation groups:\n";
- for (const auto &group : groups) {
- llvm::dbgs() << " Group type: "
- << (group.isLoadGroup() ? "Load" : "Store")
- << ", size: " << group.size() << "\n";
- }
- });
- }
- }
- }
- }
- }
+ // Walk all blocks recursively
+ op->walk([&](Block *block) {
+ LLVM_DEBUG(llvm::dbgs() << "Processing block in operation: "
+ << block->getParentOp()->getName() << "\n");
- llvm::errs() << "Running SLP Vectorizer pass\n";
+ // Collect memory operation groups
+ SmallVector<MemoryOpGroup> groups = collectMemoryOpGroups(*block);
+
+ LLVM_DEBUG(llvm::dbgs() << "Found " << groups.size()
+ << " memory operation groups in block\n");
+ });
}
std::unique_ptr<Pass> mlir::vector::createSLPVectorizerPass() {
>From b2c5c5cdd73aac4374bc6015159ec08f95b06082 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 00:10:20 +0200
Subject: [PATCH 04/27] contiguous groups
---
.../Vector/Transforms/SLPVectorizer.cpp | 106 +++++++++++++++++-
1 file changed, 104 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index bec5f9d90b21b..f46dc71537ef3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -51,6 +51,96 @@ struct MemoryOpGroup {
bool empty() const { return ops.empty(); }
};
+// Extract contiguous groups from a MemoryOpGroup
+SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
+ SmallVector<MemoryOpGroup> result;
+ if (group.ops.empty())
+ return result;
+
+ // Keep track of which operations we've processed
+ DenseSet<Operation *> processedOps;
+
+ // Process each operation
+ for (Operation *op : group.ops) {
+ // Skip if we've already processed this operation
+ if (processedOps.contains(op))
+ continue;
+
+ // Get base and index of current operation
+ Value base;
+ int64_t index = -1;
+ if (group.isLoadGroup()) {
+ auto loadOp = cast<memref::LoadOp>(op);
+ if (auto value = getConstantIntValue(loadOp.getIndices().front())) {
+ index = *value;
+ base = loadOp.getMemRef();
+ }
+ } else {
+ auto storeOp = cast<memref::StoreOp>(op);
+ if (auto value = getConstantIntValue(storeOp.getIndices().front())) {
+ index = *value;
+ base = storeOp.getMemRef();
+ }
+ }
+ if (index == -1)
+ continue;
+
+ // Start a new group with this operation
+ result.emplace_back(group.type);
+ MemoryOpGroup &currentGroup = result.back();
+ currentGroup.ops.push_back(op);
+ processedOps.insert(op);
+
+ LLVM_DEBUG(llvm::dbgs() << "Starting new group at base " << base
+ << " index " << index << "\n");
+
+ // Try to find operations with adjacent indices
+ bool foundMore;
+ do {
+ foundMore = false;
+ // Look for operations with index+1
+ for (Operation *otherOp : group.ops) {
+ if (processedOps.contains(otherOp))
+ continue;
+
+ Value otherBase;
+ int64_t otherIndex = -1;
+ if (group.isLoadGroup()) {
+ auto loadOp = cast<memref::LoadOp>(otherOp);
+ if (auto value = getConstantIntValue(loadOp.getIndices().front())) {
+ otherIndex = *value;
+ otherBase = loadOp.getMemRef();
+ }
+ } else {
+ auto storeOp = cast<memref::StoreOp>(otherOp);
+ if (auto value = getConstantIntValue(storeOp.getIndices().front())) {
+ otherIndex = *value;
+ otherBase = storeOp.getMemRef();
+ }
+ }
+
+ // Check if this operation has the same base and adjacent index
+ if (otherIndex != -1 && otherBase == base &&
+ otherIndex == currentGroup.ops.size()) {
+ currentGroup.ops.push_back(otherOp);
+ processedOps.insert(otherOp);
+ foundMore = true;
+ LLVM_DEBUG(llvm::dbgs()
+ << "Added operation with index " << otherIndex << "\n");
+ break;
+ }
+ }
+ } while (foundMore);
+ }
+
+ // Remove empty groups
+ result.erase(std::remove_if(result.begin(), result.end(),
+ [](const MemoryOpGroup &g) { return g.empty(); }),
+ result.end());
+
+ return result;
+}
+
/// This pass implements the SLP vectorizer. It detects consecutive operations
/// that can be put together into vector operations. The pass works bottom-up,
/// across basic blocks, in search of scalars to combine.
@@ -112,8 +202,20 @@ void SLPVectorizerPass::runOnOperation() {
// Collect memory operation groups
SmallVector<MemoryOpGroup> groups = collectMemoryOpGroups(*block);
- LLVM_DEBUG(llvm::dbgs() << "Found " << groups.size()
- << " memory operation groups in block\n");
+ // Process each group to find contiguous sequences
+ for (const auto &group : groups) {
+ SmallVector<MemoryOpGroup> contiguousGroups =
+ extractContiguousGroups(group);
+ LLVM_DEBUG({
+ llvm::dbgs() << "Found " << contiguousGroups.size()
+ << " contiguous groups in "
+ << (group.isLoadGroup() ? "load" : "store") << " group\n";
+ for (const auto &contigGroup : contiguousGroups) {
+ llvm::dbgs() << " Contiguous group with " << contigGroup.size()
+ << " operations\n";
+ }
+ });
+ }
});
}
>From e01213d8ec4284b24732050f14caca64d18a0123 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 00:14:53 +0200
Subject: [PATCH 05/27] refac
---
.../Vector/Transforms/SLPVectorizer.cpp | 55 ++++++++-----------
1 file changed, 22 insertions(+), 33 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index f46dc71537ef3..9a0ba5264bc40 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -51,6 +51,18 @@ struct MemoryOpGroup {
bool empty() const { return ops.empty(); }
};
+// Helper function to extract base and index from a memory operation
+std::optional<std::pair<Value, int64_t>> getBaseAndIndex(Operation *op) {
+ if (auto loadOp = dyn_cast<memref::LoadOp>(op)) {
+ if (auto value = getConstantIntValue(loadOp.getIndices().front()))
+ return std::make_pair(loadOp.getMemRef(), *value);
+ } else if (auto storeOp = dyn_cast<memref::StoreOp>(op)) {
+ if (auto value = getConstantIntValue(storeOp.getIndices().front()))
+ return std::make_pair(storeOp.getMemRef(), *value);
+ }
+ return std::nullopt;
+}
+
// Extract contiguous groups from a MemoryOpGroup
SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
SmallVector<MemoryOpGroup> result;
@@ -67,24 +79,12 @@ SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
continue;
// Get base and index of current operation
- Value base;
- int64_t index = -1;
- if (group.isLoadGroup()) {
- auto loadOp = cast<memref::LoadOp>(op);
- if (auto value = getConstantIntValue(loadOp.getIndices().front())) {
- index = *value;
- base = loadOp.getMemRef();
- }
- } else {
- auto storeOp = cast<memref::StoreOp>(op);
- if (auto value = getConstantIntValue(storeOp.getIndices().front())) {
- index = *value;
- base = storeOp.getMemRef();
- }
- }
- if (index == -1)
+ auto baseAndIndex = getBaseAndIndex(op);
+ if (!baseAndIndex)
continue;
+ auto [base, index] = *baseAndIndex;
+
// Start a new group with this operation
result.emplace_back(group.type);
MemoryOpGroup &currentGroup = result.back();
@@ -103,25 +103,14 @@ SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
if (processedOps.contains(otherOp))
continue;
- Value otherBase;
- int64_t otherIndex = -1;
- if (group.isLoadGroup()) {
- auto loadOp = cast<memref::LoadOp>(otherOp);
- if (auto value = getConstantIntValue(loadOp.getIndices().front())) {
- otherIndex = *value;
- otherBase = loadOp.getMemRef();
- }
- } else {
- auto storeOp = cast<memref::StoreOp>(otherOp);
- if (auto value = getConstantIntValue(storeOp.getIndices().front())) {
- otherIndex = *value;
- otherBase = storeOp.getMemRef();
- }
- }
+ auto otherBaseAndIndex = getBaseAndIndex(otherOp);
+ if (!otherBaseAndIndex)
+ continue;
+
+ auto [otherBase, otherIndex] = *otherBaseAndIndex;
// Check if this operation has the same base and adjacent index
- if (otherIndex != -1 && otherBase == base &&
- otherIndex == currentGroup.ops.size()) {
+ if (otherBase == base && otherIndex == currentGroup.ops.size()) {
currentGroup.ops.push_back(otherOp);
processedOps.insert(otherOp);
foundMore = true;
>From 99158188e89ab1d35947370e8f5be7369d875b97 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 00:35:42 +0200
Subject: [PATCH 06/27] SLPGraph
---
.../Vector/Transforms/SLPVectorizer.cpp | 162 ++++++++++++++++++
1 file changed, 162 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 9a0ba5264bc40..4355dc33648c0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -130,6 +130,160 @@ SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
return result;
}
+/// A node in the SLP graph representing a vectorizable operation
+struct SLPGraphNode {
+ Operation *op;
+ DenseSet<SLPGraphNode *> users;
+ DenseSet<SLPGraphNode *> operands;
+ bool isRoot = false;
+
+ SLPGraphNode(Operation *op) : op(op) {}
+};
+
+/// A graph of vectorizable operations
+class SLPGraph {
+public:
+ SLPGraph() = default;
+ ~SLPGraph() {
+ for (auto *node : nodes)
+ delete node;
+ }
+
+ /// Add a new node to the graph
+ SLPGraphNode *addNode(Operation *op) {
+ nodes.push_back(new SLPGraphNode(op));
+ return nodes.back();
+ }
+
+ /// Add a root node (memory operation)
+ SLPGraphNode *addRoot(Operation *op) {
+ auto *node = addNode(op);
+ node->isRoot = true;
+ return node;
+ }
+
+ /// Add a dependency edge between nodes
+ void addEdge(SLPGraphNode *from, SLPGraphNode *to) {
+ from->users.insert(to);
+ to->operands.insert(from);
+ }
+
+ /// Get all root nodes
+ SmallVector<SLPGraphNode *> getRoots() const {
+ SmallVector<SLPGraphNode *> roots;
+ for (auto *node : nodes)
+ if (node->isRoot)
+ roots.push_back(node);
+ return roots;
+ }
+
+ /// Print the graph structure
+ void print() const {
+ llvm::dbgs() << "SLP Graph Structure:\n";
+ llvm::dbgs() << "===================\n";
+
+ // First print all roots
+ llvm::dbgs() << "Roots:\n";
+ for (auto *node : nodes) {
+ if (!node->isRoot)
+ continue;
+ llvm::dbgs() << " " << *node->op << "\n";
+ llvm::dbgs() << " Users: ";
+ for (auto *user : node->users) {
+ llvm::dbgs() << "\n " << *user->op;
+ }
+ llvm::dbgs() << "\n";
+ }
+
+ // Then print all non-root nodes
+ llvm::dbgs() << "\nNon-root nodes:\n";
+ for (auto *node : nodes) {
+ if (node->isRoot)
+ continue;
+ llvm::dbgs() << " " << *node->op << "\n";
+ llvm::dbgs() << " Operands: ";
+ for (auto *operand : node->operands) {
+ llvm::dbgs() << "\n " << *operand->op;
+ }
+ llvm::dbgs() << "\n Users: ";
+ for (auto *user : node->users) {
+ llvm::dbgs() << "\n " << *user->op;
+ }
+ llvm::dbgs() << "\n";
+ }
+ llvm::dbgs() << "===================\n";
+ }
+
+private:
+ SmallVector<SLPGraphNode *> nodes;
+};
+
+/// Build the SLP graph starting from memory operation roots
+SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
+ SLPGraph graph;
+ DenseMap<Operation *, SLPGraphNode *> opToNode;
+
+ // First, add all memory operations as roots
+ for (const auto &group : rootGroups) {
+ for (Operation *op : group.ops) {
+ opToNode[op] = graph.addRoot(op);
+ }
+ }
+
+ // Process each root group to build the graph
+ for (const auto &group : rootGroups) {
+ for (Operation *rootOp : group.ops) {
+ // Get the value produced by this memory operation
+ Value rootValue = group.isLoadGroup()
+ ? cast<memref::LoadOp>(rootOp).getResult()
+ : cast<memref::StoreOp>(rootOp).getValue();
+
+ // Find all users of this value
+ for (Operation *user : rootValue.getUsers()) {
+ // Skip if we've already processed this operation
+ if (opToNode.contains(user))
+ continue;
+
+ // Check if this is a vectorizable operation
+ if (isa<arith::AddFOp, arith::AddIOp, arith::SubFOp, arith::SubIOp,
+ arith::MulFOp, arith::MulIOp>(user)) {
+ // Check if at least one other operand is already in the graph
+ bool hasGraphOperand = false;
+ for (Value operand : user->getOperands()) {
+ if (operand == rootValue)
+ continue;
+ if (auto *defOp = operand.getDefiningOp()) {
+ if (opToNode.contains(defOp)) {
+ hasGraphOperand = true;
+ break;
+ }
+ }
+ }
+
+ // Only add the operation if it has at least one other operand in the
+ // graph
+ if (hasGraphOperand) {
+ auto *node = graph.addNode(user);
+ opToNode[user] = node;
+ graph.addEdge(opToNode[rootOp], node);
+
+ // Add edges from other operands that are in the graph
+ for (Value operand : user->getOperands()) {
+ if (auto *defOp = operand.getDefiningOp()) {
+ if (opToNode.contains(defOp)) {
+ graph.addEdge(opToNode[defOp], node);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ return graph;
+}
+
/// This pass implements the SLP vectorizer. It detects consecutive operations
/// that can be put together into vector operations. The pass works bottom-up,
/// across basic blocks, in search of scalars to combine.
@@ -192,6 +346,7 @@ void SLPVectorizerPass::runOnOperation() {
SmallVector<MemoryOpGroup> groups = collectMemoryOpGroups(*block);
// Process each group to find contiguous sequences
+ SmallVector<MemoryOpGroup> rootGroups;
for (const auto &group : groups) {
SmallVector<MemoryOpGroup> contiguousGroups =
extractContiguousGroups(group);
@@ -204,7 +359,14 @@ void SLPVectorizerPass::runOnOperation() {
<< " operations\n";
}
});
+ rootGroups.append(contiguousGroups.begin(), contiguousGroups.end());
}
+
+ // Build the SLP graph from root groups
+ SLPGraph graph = buildSLPGraph(rootGroups);
+
+ // Print the graph structure
+ LLVM_DEBUG(graph.print());
});
}
>From 4a5409137117ac56f1b4d6b8f8f28f7eb6291f22 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 00:53:58 +0200
Subject: [PATCH 07/27] SLPGraph
---
.../Vector/Transforms/SLPVectorizer.cpp | 129 ++++++------------
1 file changed, 38 insertions(+), 91 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 4355dc33648c0..3c4fc3a377244 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -130,29 +130,28 @@ SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
return result;
}
-/// A node in the SLP graph representing a vectorizable operation
+/// A node in the SLP graph representing a group of vectorizable operations
struct SLPGraphNode {
- Operation *op;
+ SmallVector<Operation *> ops;
DenseSet<SLPGraphNode *> users;
DenseSet<SLPGraphNode *> operands;
bool isRoot = false;
- SLPGraphNode(Operation *op) : op(op) {}
+ SLPGraphNode() = default;
+ SLPGraphNode(Operation *op) { ops.push_back(op); }
+ void addOp(Operation *op) { ops.push_back(op); }
};
/// A graph of vectorizable operations
class SLPGraph {
public:
SLPGraph() = default;
- ~SLPGraph() {
- for (auto *node : nodes)
- delete node;
- }
+ ~SLPGraph() = default;
/// Add a new node to the graph
SLPGraphNode *addNode(Operation *op) {
- nodes.push_back(new SLPGraphNode(op));
- return nodes.back();
+ nodes.push_back(std::make_unique<SLPGraphNode>(op));
+ return nodes.back().get();
}
/// Add a root node (memory operation)
@@ -171,9 +170,9 @@ class SLPGraph {
/// Get all root nodes
SmallVector<SLPGraphNode *> getRoots() const {
SmallVector<SLPGraphNode *> roots;
- for (auto *node : nodes)
+ for (const auto &node : nodes)
if (node->isRoot)
- roots.push_back(node);
+ roots.push_back(node.get());
return roots;
}
@@ -184,30 +183,50 @@ class SLPGraph {
// First print all roots
llvm::dbgs() << "Roots:\n";
- for (auto *node : nodes) {
+ for (const auto &node : nodes) {
if (!node->isRoot)
continue;
- llvm::dbgs() << " " << *node->op << "\n";
+ llvm::dbgs() << " "
+ << (isa<memref::LoadOp>(node->ops[0]) ? "LOAD" : "STORE")
+ << " group with " << node->ops.size() << " operations:\n";
+ for (auto *op : node->ops) {
+ llvm::dbgs() << " " << *op << "\n";
+ }
llvm::dbgs() << " Users: ";
for (auto *user : node->users) {
- llvm::dbgs() << "\n " << *user->op;
+ llvm::dbgs() << "\n Group with " << user->ops.size()
+ << " operations:";
+ for (auto *op : user->ops) {
+ llvm::dbgs() << "\n " << *op;
+ }
}
llvm::dbgs() << "\n";
}
// Then print all non-root nodes
llvm::dbgs() << "\nNon-root nodes:\n";
- for (auto *node : nodes) {
+ for (const auto &node : nodes) {
if (node->isRoot)
continue;
- llvm::dbgs() << " " << *node->op << "\n";
+ llvm::dbgs() << " Group with " << node->ops.size() << " operations:\n";
+ for (auto *op : node->ops) {
+ llvm::dbgs() << " " << *op << "\n";
+ }
llvm::dbgs() << " Operands: ";
for (auto *operand : node->operands) {
- llvm::dbgs() << "\n " << *operand->op;
+ llvm::dbgs() << "\n Group with " << operand->ops.size()
+ << " operations:";
+ for (auto *op : operand->ops) {
+ llvm::dbgs() << "\n " << *op;
+ }
}
llvm::dbgs() << "\n Users: ";
for (auto *user : node->users) {
- llvm::dbgs() << "\n " << *user->op;
+ llvm::dbgs() << "\n Group with " << user->ops.size()
+ << " operations:";
+ for (auto *op : user->ops) {
+ llvm::dbgs() << "\n " << *op;
+ }
}
llvm::dbgs() << "\n";
}
@@ -215,75 +234,9 @@ class SLPGraph {
}
private:
- SmallVector<SLPGraphNode *> nodes;
+ SmallVector<std::unique_ptr<SLPGraphNode>> nodes;
};
-/// Build the SLP graph starting from memory operation roots
-SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
- SLPGraph graph;
- DenseMap<Operation *, SLPGraphNode *> opToNode;
-
- // First, add all memory operations as roots
- for (const auto &group : rootGroups) {
- for (Operation *op : group.ops) {
- opToNode[op] = graph.addRoot(op);
- }
- }
-
- // Process each root group to build the graph
- for (const auto &group : rootGroups) {
- for (Operation *rootOp : group.ops) {
- // Get the value produced by this memory operation
- Value rootValue = group.isLoadGroup()
- ? cast<memref::LoadOp>(rootOp).getResult()
- : cast<memref::StoreOp>(rootOp).getValue();
-
- // Find all users of this value
- for (Operation *user : rootValue.getUsers()) {
- // Skip if we've already processed this operation
- if (opToNode.contains(user))
- continue;
-
- // Check if this is a vectorizable operation
- if (isa<arith::AddFOp, arith::AddIOp, arith::SubFOp, arith::SubIOp,
- arith::MulFOp, arith::MulIOp>(user)) {
- // Check if at least one other operand is already in the graph
- bool hasGraphOperand = false;
- for (Value operand : user->getOperands()) {
- if (operand == rootValue)
- continue;
- if (auto *defOp = operand.getDefiningOp()) {
- if (opToNode.contains(defOp)) {
- hasGraphOperand = true;
- break;
- }
- }
- }
-
- // Only add the operation if it has at least one other operand in the
- // graph
- if (hasGraphOperand) {
- auto *node = graph.addNode(user);
- opToNode[user] = node;
- graph.addEdge(opToNode[rootOp], node);
-
- // Add edges from other operands that are in the graph
- for (Value operand : user->getOperands()) {
- if (auto *defOp = operand.getDefiningOp()) {
- if (opToNode.contains(defOp)) {
- graph.addEdge(opToNode[defOp], node);
- }
- }
- }
- }
- }
- }
- }
- }
-
- return graph;
-}
-
/// This pass implements the SLP vectorizer. It detects consecutive operations
/// that can be put together into vector operations. The pass works bottom-up,
/// across basic blocks, in search of scalars to combine.
@@ -361,12 +314,6 @@ void SLPVectorizerPass::runOnOperation() {
});
rootGroups.append(contiguousGroups.begin(), contiguousGroups.end());
}
-
- // Build the SLP graph from root groups
- SLPGraph graph = buildSLPGraph(rootGroups);
-
- // Print the graph structure
- LLVM_DEBUG(graph.print());
});
}
>From 96e0fe8232a27d47017b4e7366913b4e5966ef78 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 01:04:27 +0200
Subject: [PATCH 08/27] work
---
.../Vector/Transforms/SLPVectorizer.cpp | 48 ++++++++++++++++---
1 file changed, 41 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 3c4fc3a377244..8e49b622ac39b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -138,8 +138,8 @@ struct SLPGraphNode {
bool isRoot = false;
SLPGraphNode() = default;
- SLPGraphNode(Operation *op) { ops.push_back(op); }
- void addOp(Operation *op) { ops.push_back(op); }
+ SLPGraphNode(ArrayRef<Operation *> operations)
+ : ops(operations.begin(), operations.end()) {}
};
/// A graph of vectorizable operations
@@ -148,15 +148,23 @@ class SLPGraph {
SLPGraph() = default;
~SLPGraph() = default;
+ // Delete copy constructor and assignment operator
+ SLPGraph(const SLPGraph &) = delete;
+ SLPGraph &operator=(const SLPGraph &) = delete;
+
+ // Allow move operations
+ SLPGraph(SLPGraph &&) = default;
+ SLPGraph &operator=(SLPGraph &&) = default;
+
/// Add a new node to the graph
- SLPGraphNode *addNode(Operation *op) {
- nodes.push_back(std::make_unique<SLPGraphNode>(op));
+ SLPGraphNode *addNode(ArrayRef<Operation *> operations) {
+ nodes.push_back(std::make_unique<SLPGraphNode>(operations));
return nodes.back().get();
}
/// Add a root node (memory operation)
- SLPGraphNode *addRoot(Operation *op) {
- auto *node = addNode(op);
+ SLPGraphNode *addRoot(ArrayRef<Operation *> operations) {
+ auto *node = addNode(operations);
node->isRoot = true;
return node;
}
@@ -251,7 +259,25 @@ struct SLPVectorizerPass
SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block);
};
-} // namespace
+/// Build the SLP graph starting from memory operation groups
+SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
+ SLPGraph graph;
+
+ // First, create nodes for each contiguous memory operation group
+ for (const auto &group : rootGroups) {
+ // Create a new node for this group
+ auto *node = graph.addRoot(group.ops);
+ node->isRoot = true;
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "Created " << (group.isLoadGroup() ? "LOAD" : "STORE")
+ << " group node with " << node->ops.size()
+ << " operations\n";
+ });
+ }
+
+ return graph;
+}
SmallVector<MemoryOpGroup>
SLPVectorizerPass::collectMemoryOpGroups(Block &block) {
@@ -314,9 +340,17 @@ void SLPVectorizerPass::runOnOperation() {
});
rootGroups.append(contiguousGroups.begin(), contiguousGroups.end());
}
+
+ // Build the SLP graph from root groups
+ SLPGraph graph = buildSLPGraph(rootGroups);
+
+ // Print the graph structure
+ LLVM_DEBUG(graph.print());
});
}
+} // namespace
+
std::unique_ptr<Pass> mlir::vector::createSLPVectorizerPass() {
return std::make_unique<SLPVectorizerPass>();
}
>From c19e55782daed0c57bb4604192dc73fa50ff4709 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 12:55:28 +0200
Subject: [PATCH 09/27] fingerprinting
---
.../Vector/Transforms/SLPVectorizer.cpp | 170 ++++++++++++++++--
1 file changed, 158 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 8e49b622ac39b..3e6a4ca05f87d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -21,6 +21,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/SHA1.h"
#define DEBUG_TYPE "slp-vectorizer"
@@ -64,7 +65,8 @@ std::optional<std::pair<Value, int64_t>> getBaseAndIndex(Operation *op) {
}
// Extract contiguous groups from a MemoryOpGroup
-SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
+static SmallVector<MemoryOpGroup>
+extractContiguousGroups(const MemoryOpGroup &group) {
SmallVector<MemoryOpGroup> result;
if (group.ops.empty())
return result;
@@ -133,8 +135,8 @@ SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
/// A node in the SLP graph representing a group of vectorizable operations
struct SLPGraphNode {
SmallVector<Operation *> ops;
- DenseSet<SLPGraphNode *> users;
- DenseSet<SLPGraphNode *> operands;
+ llvm::SmallDenseSet<SLPGraphNode *> users;
+ llvm::SmallDenseSet<SLPGraphNode *> operands;
bool isRoot = false;
SLPGraphNode() = default;
@@ -148,11 +150,9 @@ class SLPGraph {
SLPGraph() = default;
~SLPGraph() = default;
- // Delete copy constructor and assignment operator
SLPGraph(const SLPGraph &) = delete;
SLPGraph &operator=(const SLPGraph &) = delete;
- // Allow move operations
SLPGraph(SLPGraph &&) = default;
SLPGraph &operator=(SLPGraph &&) = default;
@@ -259,21 +259,168 @@ struct SLPVectorizerPass
SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block);
};
+static bool isVectorizable(Operation *op) {
+ return OpTrait::hasElementwiseMappableTraits(op);
+}
+
+using Fingerprint = std::array<uint8_t, 20>;
+
+template <typename T>
+static void addDataToHash(llvm::SHA1 &hasher, const T &data) { // feeds the raw sizeof(T) bytes of data into hasher
+ hasher.update(
+ ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
+}
+
+struct OperationsFingerprint {
+ OperationsFingerprint(
+ const llvm::SmallDenseMap<Operation *, SLPGraphNode *> &opToNode)
+ : opToNode(opToNode) {}
+
+ Fingerprint getFingerprint(Operation *op) {
+ auto it = fingerprints.find(op);
+ if (it != fingerprints.end()) // reuse the fingerprint cached by an earlier query
+ return it->second;
+
+ SmallVector<Operation *> worklist;
+ SmallVector<Operation *> toposortedOps; // discovery order; hashed in reverse below
+ worklist.emplace_back(op);
+ while (!worklist.empty()) {
+ Operation *op = worklist.pop_back_val();
+ toposortedOps.emplace_back(op);
+ if (opToNode.contains(op))
+ continue;
+
+ for (Value operand : op->getOperands()) {
+ auto *defOp = operand.getDefiningOp();
+ if (!defOp || !isVectorizable(defOp)) // only walk through vectorizable producers
+ continue;
+
+ toposortedOps.emplace_back(defOp); // NOTE(review): defOp is appended again when popped above, so entries repeat
+ worklist.emplace_back(defOp);
+ }
+ }
+
+ for (Operation *op : llvm::reverse(toposortedOps)) { // last-discovered ops are hashed first
+ llvm::SHA1 hasher;
+ addDataToHash(hasher, op->getName().getTypeID());
+ addDataToHash(hasher, op->getRawDictionaryAttrs());
+ addDataToHash(hasher, op->hashProperties());
+ for (Value operand : op->getOperands()) {
+ auto *defOp = operand.getDefiningOp(); // operands with a null defOp add nothing to the hash
+ if (!defOp)
+ continue;
+
+ auto it1 = opToNode.find(defOp);
+ if (it1 != opToNode.end()) {
+ addDataToHash(hasher, it1->second);
+ continue;
+ }
+
+ auto it2 = fingerprints.find(defOp); // otherwise fold in the producer's fingerprint, if one exists
+ if (it2 != fingerprints.end()) {
+ addDataToHash(hasher, it2->second);
+ continue;
+ }
+ }
+ fingerprints[op] = hasher.result();
+ }
+
+ return fingerprints[op]; // op is the function argument here; the loop variables above only shadow it
+ }
+
+ void invalidate(Operation *op) { // drops the whole cache, not just op's entry
+ if (fingerprints.contains(op))
+ fingerprints.clear();
+ }
+
+ const llvm::SmallDenseMap<Operation *, SLPGraphNode *> &opToNode;
+ DenseMap<Operation *, Fingerprint> fingerprints;
+};
+
+static bool isEquivalent(Operation *op1, Operation *op2) { // compares name and attributes only, not operands
+ if (op1->getName() != op2->getName()) // must be the same operation name
+ return false;
+
+ if (op1->getRawDictionaryAttrs() != op2->getRawDictionaryAttrs()) // and carry the same attribute dictionary
+ return false;
+
+ return true;
+}
+
/// Build the SLP graph starting from memory operation groups
-SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
+static SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
SLPGraph graph;
+ llvm::SmallDenseMap<Operation *, SLPGraphNode *> opToNode;
+
+ SmallVector<SLPGraphNode *> worklist;
// First, create nodes for each contiguous memory operation group
for (const auto &group : rootGroups) {
- // Create a new node for this group
auto *node = graph.addRoot(group.ops);
- node->isRoot = true;
+ for (Operation *op : group.ops)
+ opToNode[op] = node;
+
+ worklist.push_back(node);
LLVM_DEBUG({
- llvm::dbgs() << "Created " << (group.isLoadGroup() ? "LOAD" : "STORE")
- << " group node with " << node->ops.size()
- << " operations\n";
+ llvm::dbgs() << "Created root group node with " << node->ops.size()
+ << " operations of type "
+ << (group.type == MemoryOpGroup::Type::Load ? "Load"
+ : "Store")
+ << "\n";
});
+
+ OperationsFingerprint fingerprints(opToNode);
+
+ auto processUse = [&](SLPGraphNode *node, OpOperand &use) {
+ Operation *user = use.getOwner();
+ if (opToNode.contains(user) || !isVectorizable(user))
+ return;
+
+ Fingerprint expectedFingerprint = fingerprints.getFingerprint(user);
+
+ SmallVector<Operation *> currentOps;
+ currentOps.emplace_back(user);
+ for (Operation *op : ArrayRef(node->ops).drop_front()) {
+ Operation *found = nullptr;
+ for (OpOperand &opUse : op->getUses()) {
+ if (opUse.getOperandNumber() != use.getOperandNumber())
+ continue;
+
+ Operation *useOwner = opUse.getOwner();
+ if (!isEquivalent(useOwner, user) ||
+ fingerprints.getFingerprint(useOwner) != expectedFingerprint)
+ continue;
+
+ found = useOwner;
+ break;
+ }
+ if (!found)
+ break;
+
+ currentOps.push_back(found);
+ }
+
+ if (currentOps.size() == 1)
+ return;
+
+ auto *newNode = graph.addNode(currentOps);
+ graph.addEdge(node, newNode);
+ for (Operation *op : currentOps) {
+ opToNode[op] = newNode;
+ fingerprints.invalidate(op);
+ }
+
+ worklist.push_back(newNode);
+ };
+
+ while (!worklist.empty()) {
+ SLPGraphNode *node = worklist.pop_back_val();
+
+ Operation *op = node->ops.front();
+ for (OpOperand &use : op->getUses())
+ processUse(node, use);
+ }
}
return graph;
@@ -314,7 +461,6 @@ SLPVectorizerPass::collectMemoryOpGroups(Block &block) {
void SLPVectorizerPass::runOnOperation() {
Operation *op = getOperation();
- MLIRContext *context = &getContext();
// Walk all blocks recursively
op->walk([&](Block *block) {
>From dfdfc8948cdfd70d068234d310a422eeaef45e3a Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 13:58:50 +0200
Subject: [PATCH 10/27] graph
---
.../Vector/Transforms/SLPVectorizer.cpp | 106 ++++++++++--------
1 file changed, 62 insertions(+), 44 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 3e6a4ca05f87d..28c53efea7512 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -185,7 +185,7 @@ class SLPGraph {
}
/// Print the graph structure
- void print() const {
+ [[maybe_unused]] void print() const {
llvm::dbgs() << "SLP Graph Structure:\n";
llvm::dbgs() << "===================\n";
@@ -348,7 +348,12 @@ static bool isEquivalent(Operation *op1, Operation *op2) {
}
/// Build the SLP graph starting from memory operation groups
-static SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
+static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
+ if (rootGroups.empty())
+ return SLPGraph();
+
+ LLVM_DEBUG(llvm::dbgs() << "=== Building SLP graph from " << rootGroups.size()
+ << " root groups ===\n");
SLPGraph graph;
llvm::SmallDenseMap<Operation *, SLPGraphNode *> opToNode;
@@ -365,61 +370,74 @@ static SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
LLVM_DEBUG({
llvm::dbgs() << "Created root group node with " << node->ops.size()
<< " operations of type "
- << (group.type == MemoryOpGroup::Type::Load ? "Load"
- : "Store")
- << "\n";
+ << (group.isLoadGroup() ? "Load" : "Store") << "\n";
});
+ }
- OperationsFingerprint fingerprints(opToNode);
-
- auto processUse = [&](SLPGraphNode *node, OpOperand &use) {
- Operation *user = use.getOwner();
- if (opToNode.contains(user) || !isVectorizable(user))
- return;
+ OperationsFingerprint fingerprints(opToNode);
+
+ auto processUse = [&](SLPGraphNode *node, OpOperand &use) {
+ Operation *user = use.getOwner();
+ auto it = opToNode.find(user);
+ if (it != opToNode.end()) {
+ LLVM_DEBUG(llvm::dbgs()
+ << " Adding edge from " << node->ops.front()->getName()
+ << " to " << it->first->getName() << "\n");
+ graph.addEdge(node, it->second);
+ return;
+ }
- Fingerprint expectedFingerprint = fingerprints.getFingerprint(user);
+ if (!isVectorizable(user))
+ return;
- SmallVector<Operation *> currentOps;
- currentOps.emplace_back(user);
- for (Operation *op : ArrayRef(node->ops).drop_front()) {
- Operation *found = nullptr;
- for (OpOperand &opUse : op->getUses()) {
- if (opUse.getOperandNumber() != use.getOperandNumber())
- continue;
+ Fingerprint expectedFingerprint = fingerprints.getFingerprint(user);
- Operation *useOwner = opUse.getOwner();
- if (!isEquivalent(useOwner, user) ||
- fingerprints.getFingerprint(useOwner) != expectedFingerprint)
- continue;
+ SmallVector<Operation *> currentOps;
+ currentOps.emplace_back(user);
+ for (Operation *op : ArrayRef(node->ops).drop_front()) {
+ Operation *found = nullptr;
+ for (OpOperand &opUse : op->getUses()) {
+ if (opUse.getOperandNumber() != use.getOperandNumber())
+ continue;
- found = useOwner;
- break;
- }
- if (!found)
- break;
+ Operation *useOwner = opUse.getOwner();
+ if (!isEquivalent(useOwner, user) ||
+ fingerprints.getFingerprint(useOwner) != expectedFingerprint)
+ continue;
- currentOps.push_back(found);
+ found = useOwner;
+ break;
}
+ if (!found)
+ break;
- if (currentOps.size() == 1)
- return;
+ currentOps.push_back(found);
+ }
- auto *newNode = graph.addNode(currentOps);
- graph.addEdge(node, newNode);
- for (Operation *op : currentOps) {
- opToNode[op] = newNode;
- fingerprints.invalidate(op);
- }
+ if (currentOps.size() == 1)
+ return;
- worklist.push_back(newNode);
- };
+ auto *newNode = graph.addNode(currentOps);
+ graph.addEdge(node, newNode);
+ for (Operation *op : currentOps) {
+ opToNode[op] = newNode;
+ fingerprints.invalidate(op);
+ }
- while (!worklist.empty()) {
- SLPGraphNode *node = worklist.pop_back_val();
+ worklist.push_back(newNode);
+ };
+
+ while (!worklist.empty()) {
+ SLPGraphNode *node = worklist.pop_back_val();
+ LLVM_DEBUG(llvm::dbgs() << "Processing node with " << node->ops.size()
+ << " operations, first op: "
+ << node->ops.front()->getName() << "\n");
- Operation *op = node->ops.front();
- for (OpOperand &use : op->getUses())
- processUse(node, use);
+ Operation *op = node->ops.front();
+ for (OpOperand &use : op->getUses()) {
+ processUse(node, use);
+ LLVM_DEBUG(llvm::dbgs() << " Processing use in operation: "
+ << use.getOwner()->getName() << "\n");
}
}
>From 3272e0192e6c6a4c2abd6e170ffa79ede3aa07da Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 14:05:07 +0200
Subject: [PATCH 11/27] refac
---
.../Vector/Transforms/SLPVectorizer.cpp | 41 ++++++++++---------
1 file changed, 22 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 28c53efea7512..e3b39ba10373c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -159,7 +159,10 @@ class SLPGraph {
/// Add a new node to the graph
SLPGraphNode *addNode(ArrayRef<Operation *> operations) {
nodes.push_back(std::make_unique<SLPGraphNode>(operations));
- return nodes.back().get();
+ auto *node = nodes.back().get();
+ for (Operation *op : operations)
+ opToNode[op] = node;
+ return node;
}
/// Add a root node (memory operation)
@@ -184,6 +187,12 @@ class SLPGraph {
return roots;
}
+ /// Get the node that contains `op`, or nullptr if `op` is not in the graph
+ SLPGraphNode *getNodeForOp(Operation *op) const {
+ auto it = opToNode.find(op);
+ return it != opToNode.end() ? it->second : nullptr;
+ }
+
/// Print the graph structure
[[maybe_unused]] void print() const {
llvm::dbgs() << "SLP Graph Structure:\n";
@@ -243,6 +252,7 @@ class SLPGraph {
private:
SmallVector<std::unique_ptr<SLPGraphNode>> nodes;
+ llvm::SmallDenseMap<Operation *, SLPGraphNode *> opToNode;
};
/// This pass implements the SLP vectorizer. It detects consecutive operations
@@ -272,9 +282,7 @@ static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
}
struct OperationsFingerprint {
- OperationsFingerprint(
- const llvm::SmallDenseMap<Operation *, SLPGraphNode *> &opToNode)
- : opToNode(opToNode) {}
+ OperationsFingerprint(const SLPGraph &graph) : graph(graph) {}
Fingerprint getFingerprint(Operation *op) {
auto it = fingerprints.find(op);
@@ -287,7 +295,7 @@ struct OperationsFingerprint {
while (!worklist.empty()) {
Operation *op = worklist.pop_back_val();
toposortedOps.emplace_back(op);
- if (opToNode.contains(op))
+ if (graph.getNodeForOp(op))
continue;
for (Value operand : op->getOperands()) {
@@ -310,9 +318,9 @@ struct OperationsFingerprint {
if (!defOp)
continue;
- auto it1 = opToNode.find(defOp);
- if (it1 != opToNode.end()) {
- addDataToHash(hasher, it1->second);
+ auto *node = graph.getNodeForOp(defOp);
+ if (node) {
+ addDataToHash(hasher, node);
continue;
}
@@ -333,7 +341,7 @@ struct OperationsFingerprint {
fingerprints.clear();
}
- const llvm::SmallDenseMap<Operation *, SLPGraphNode *> &opToNode;
+ const SLPGraph &graph;
DenseMap<Operation *, Fingerprint> fingerprints;
};
@@ -355,16 +363,12 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
LLVM_DEBUG(llvm::dbgs() << "=== Building SLP graph from " << rootGroups.size()
<< " root groups ===\n");
SLPGraph graph;
- llvm::SmallDenseMap<Operation *, SLPGraphNode *> opToNode;
SmallVector<SLPGraphNode *> worklist;
// First, create nodes for each contiguous memory operation group
for (const auto &group : rootGroups) {
auto *node = graph.addRoot(group.ops);
- for (Operation *op : group.ops)
- opToNode[op] = node;
-
worklist.push_back(node);
LLVM_DEBUG({
@@ -374,16 +378,16 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
});
}
- OperationsFingerprint fingerprints(opToNode);
+ OperationsFingerprint fingerprints(graph);
auto processUse = [&](SLPGraphNode *node, OpOperand &use) {
Operation *user = use.getOwner();
- auto it = opToNode.find(user);
- if (it != opToNode.end()) {
+ auto *existingNode = graph.getNodeForOp(user);
+ if (existingNode) {
LLVM_DEBUG(llvm::dbgs()
<< " Adding edge from " << node->ops.front()->getName()
- << " to " << it->first->getName() << "\n");
- graph.addEdge(node, it->second);
+ << " to " << user->getName() << "\n");
+ graph.addEdge(node, existingNode);
return;
}
@@ -420,7 +424,6 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
auto *newNode = graph.addNode(currentOps);
graph.addEdge(node, newNode);
for (Operation *op : currentOps) {
- opToNode[op] = newNode;
fingerprints.invalidate(op);
}
>From 04b2316c9859d734af8f968072eb298bbf96cfd9 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 14:59:20 +0200
Subject: [PATCH 12/27] toposort
---
.../Vector/Transforms/SLPVectorizer.cpp | 89 ++++++++++++++++++-
1 file changed, 85 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index e3b39ba10373c..8f0137a12d07b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -135,8 +135,8 @@ extractContiguousGroups(const MemoryOpGroup &group) {
/// A node in the SLP graph representing a group of vectorizable operations
struct SLPGraphNode {
SmallVector<Operation *> ops;
- llvm::SmallDenseSet<SLPGraphNode *> users;
- llvm::SmallDenseSet<SLPGraphNode *> operands;
+ SmallVector<SLPGraphNode *> users;
+ SmallVector<SLPGraphNode *> operands;
bool isRoot = false;
SLPGraphNode() = default;
@@ -174,8 +174,8 @@ class SLPGraph {
/// Add a dependency edge between nodes
void addEdge(SLPGraphNode *from, SLPGraphNode *to) {
- from->users.insert(to);
- to->operands.insert(from);
+ from->users.push_back(to);
+ to->operands.push_back(from);
}
/// Get all root nodes
@@ -193,6 +193,80 @@ class SLPGraph {
return it != opToNode.end() ? it->second : nullptr;
}
+ /// Topologically sort the nodes: a node is emitted only after all its operands
+ SmallVector<SLPGraphNode *> topologicalSort() const {
+ SmallVector<SLPGraphNode *> result;
+ llvm::SmallDenseSet<SLPGraphNode *> visited; // nodes already appended to result
+
+ SmallVector<SLPGraphNode *> stack; // explicit stack for the iterative depth-first walk
+
+ // Start a walk from every node that has not been emitted yet
+ for (const auto &node : nodes) {
+ if (visited.contains(node.get()))
+ continue;
+
+ stack.emplace_back(node.get());
+ while (!stack.empty()) { // NOTE(review): terminates only if operand edges are acyclic
+ SLPGraphNode *node = stack.pop_back_val();
+ if (visited.contains(node)) // stale entry: already emitted via another stack copy
+ continue;
+
+ stack.push_back(node); // re-push so the node is revisited after its operands
+
+ bool pushed = false; // set when some operand still has to be emitted first
+ for (SLPGraphNode *operand : node->operands) {
+ if (visited.contains(operand))
+ continue;
+
+ stack.push_back(operand);
+ pushed = true;
+ }
+
+ if (!pushed) { // every operand is emitted, so this node can be emitted now
+ visited.insert(node);
+ result.push_back(node);
+ }
+ }
+ }
+
+ return result;
+ }
+
+ /// Vectorize the operations in the graph
+ LogicalResult vectorize(IRRewriter &rewriter) {
+ if (nodes.empty()) // an empty graph is trivially vectorized
+ return success();
+
+ LLVM_DEBUG(llvm::dbgs()
+ << "Vectorizing SLP graph with " << nodes.size() << " nodes\n");
+
+ // Get topologically sorted nodes
+ SmallVector<SLPGraphNode *> sortedNodes = topologicalSort();
+ if (sortedNodes.empty()) { // NOTE(review): topologicalSort() emits every node, so this looks unreachable once nodes is non-empty
+ LLVM_DEBUG(llvm::dbgs() << "Failed to topologically sort nodes\n");
+ return failure();
+ }
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "Topologically sorted nodes:\n";
+ for (auto *node : sortedNodes) {
+ llvm::dbgs() << " Node with " << node->ops.size()
+ << " operations: " << node->ops.front()->getName() << "\n";
+ }
+ });
+
+ // TODO: Implement vectorization logic:
+ // 1. Process nodes in topological order
+ // 2. For each node:
+ // a. Check if all operands are vectorized
+ // b. Create vector operation
+ // c. Replace scalar operations with vector operation
+ // 3. Handle memory operations (loads/stores) specially
+ // 4. Update use-def chains
+
+ return success();
+ }
+
/// Print the graph structure
[[maybe_unused]] void print() const {
llvm::dbgs() << "SLP Graph Structure:\n";
@@ -513,6 +587,13 @@ void SLPVectorizerPass::runOnOperation() {
// Print the graph structure
LLVM_DEBUG(graph.print());
+
+ // Vectorize the graph
+ IRRewriter rewriter(&getContext());
+ if (failed(graph.vectorize(rewriter))) {
+ LLVM_DEBUG(llvm::dbgs() << "Failed to vectorize graph\n");
+ return signalPassFailure();
+ }
});
}
>From 740011b4c25913e89015076e84b583221c6c047d Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 15:30:01 +0200
Subject: [PATCH 13/27] codegen
---
.../Vector/Transforms/SLPVectorizer.cpp | 68 +++++++++++++++----
1 file changed, 56 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 8f0137a12d07b..095ad4f11a91a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -132,6 +132,10 @@ extractContiguousGroups(const MemoryOpGroup &group) {
return result;
}
+static bool isVectorizable(Operation *op) { // elementwise-mappable ops with exactly one result
+ return OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1;
+}
+
/// A node in the SLP graph representing a group of vectorizable operations
struct SLPGraphNode {
SmallVector<Operation *> ops;
@@ -255,14 +259,58 @@ class SLPGraph {
}
});
- // TODO: Implement vectorization logic:
- // 1. Process nodes in topological order
- // 2. For each node:
- // a. Check if all operands are vectorized
- // b. Create vector operation
- // c. Replace scalar operations with vector operation
- // 3. Handle memory operations (loads/stores) specially
- // 4. Update use-def chains
+ IRMapping mapping;
+ for (auto *node : sortedNodes) {
+ if (node->users.empty() && node->operands.empty())
+ continue;
+
+ Operation *op = node->ops.front();
+ rewriter.setInsertionPoint(op);
+ Location loc = op->getLoc();
+ int64_t numElements = node->ops.size();
+ if (auto load = dyn_cast<memref::LoadOp>(op)) {
+ auto vecType =
+ VectorType::get(numElements, load.getMemRefType().getElementType());
+ Value result = rewriter.create<vector::LoadOp>(
+ loc, vecType, load.getMemRef(), load.getIndices());
+ mapping.map(load.getResult(), result);
+ } else if (auto store = dyn_cast<memref::StoreOp>(op)) {
+ Value val = mapping.lookupOrDefault(store.getValueToStore());
+ rewriter.create<vector::StoreOp>(loc, val, store.getMemRef(),
+ store.getIndices());
+ } else if (isVectorizable(op)) {
+ auto vecType =
+ VectorType::get(numElements, op->getResultTypes().front());
+ for (auto [i, operand] : llvm::enumerate(op->getOperands())) {
+ if (getNodeForOp(operand.getDefiningOp()))
+ continue;
+
+ SmallVector<Value> args;
+ for (Operation *defOp : node->ops)
+ args.push_back(defOp->getOperand(i));
+
+ Value vector =
+ rewriter.create<vector::FromElementsOp>(loc, vecType, args);
+ mapping.map(operand, vector);
+ }
+
+ Operation *newOp = rewriter.clone(*op, mapping);
+ auto resVectorType =
+ VectorType::get(numElements, op->getResultTypes().front());
+ newOp->getResult(0).setType(resVectorType);
+
+ mapping.map(op->getResults(), newOp->getResults());
+ } else {
+ op->emitError("unsupported operation");
+ return failure();
+ }
+ }
+
+ for (auto *node : llvm::reverse(sortedNodes)) {
+ for (Operation *op : node->ops) {
+ rewriter.eraseOp(op);
+ }
+ }
return success();
}
@@ -343,10 +391,6 @@ struct SLPVectorizerPass
SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block);
};
-static bool isVectorizable(Operation *op) {
- return OpTrait::hasElementwiseMappableTraits(op);
-}
-
using Fingerprint = std::array<uint8_t, 20>;
template <typename T>
>From bbd1122dd97f7719b3a75329fae58d8c72916f90 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 15:53:41 +0200
Subject: [PATCH 14/27] test
---
mlir/test/Dialect/Vector/slp-vectorize.mlir | 13 ++++++++-----
1 file changed, 8 insertions(+), 5 deletions(-)
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index a07dd05dd16aa..266008e53ea43 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -1,10 +1,13 @@
// RUN: mlir-opt %s --slp-vectorizer | FileCheck %s
-// CHECK-LABEL: func @basic_slp
-func.func @basic_slp(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
- // CHECK: vector.transfer_read
- // CHECK: arith.addi
- // CHECK: vector.transfer_write
+// CHECK-LABEL: func @read_read_add_write
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+ // CHECK: vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
>From bcbb729dea49391c7f0b0ead2e23198e6fa0b816 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 15:57:06 +0200
Subject: [PATCH 15/27] test
---
mlir/test/Dialect/Vector/slp-vectorize.mlir | 25 +++++++++++++++++++++
1 file changed, 25 insertions(+)
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 266008e53ea43..28a255f90a869 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -1,5 +1,30 @@
// RUN: mlir-opt %s --slp-vectorizer | FileCheck %s
+// CHECK-LABEL: func @read_write
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %0 = memref.load %arg0[%c0] : memref<8xi32>
+ %1 = memref.load %arg0[%c1] : memref<8xi32>
+ %2 = memref.load %arg0[%c2] : memref<8xi32>
+ %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+ memref.store %0, %arg0[%c0] : memref<8xi32>
+ memref.store %1, %arg0[%c1] : memref<8xi32>
+ memref.store %2, %arg0[%c2] : memref<8xi32>
+ memref.store %3, %arg0[%c3] : memref<8xi32>
+
+ return
+}
+
+
// CHECK-LABEL: func @read_read_add_write
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
func.func @read_read_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
>From 1f68586ad924bf8553438b40edfa6c44e2e2b017 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 16:25:45 +0200
Subject: [PATCH 16/27] fixes
---
.../Vector/Transforms/SLPVectorizer.cpp | 49 +++++++++++++------
mlir/test/Dialect/Vector/slp-vectorize.mlir | 36 ++++++++++++++
2 files changed, 70 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 095ad4f11a91a..a40131a1b10ff 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -264,24 +264,13 @@ class SLPGraph {
if (node->users.empty() && node->operands.empty())
continue;
+ int64_t numElements = node->ops.size();
Operation *op = node->ops.front();
rewriter.setInsertionPoint(op);
Location loc = op->getLoc();
- int64_t numElements = node->ops.size();
- if (auto load = dyn_cast<memref::LoadOp>(op)) {
- auto vecType =
- VectorType::get(numElements, load.getMemRefType().getElementType());
- Value result = rewriter.create<vector::LoadOp>(
- loc, vecType, load.getMemRef(), load.getIndices());
- mapping.map(load.getResult(), result);
- } else if (auto store = dyn_cast<memref::StoreOp>(op)) {
- Value val = mapping.lookupOrDefault(store.getValueToStore());
- rewriter.create<vector::StoreOp>(loc, val, store.getMemRef(),
- store.getIndices());
- } else if (isVectorizable(op)) {
- auto vecType =
- VectorType::get(numElements, op->getResultTypes().front());
- for (auto [i, operand] : llvm::enumerate(op->getOperands())) {
+
+ auto handleNonVectorInputs = [&](ValueRange operands) {
+ for (auto [i, operand] : llvm::enumerate(operands)) {
if (getNodeForOp(operand.getDefiningOp()))
continue;
@@ -289,17 +278,47 @@ class SLPGraph {
for (Operation *defOp : node->ops)
args.push_back(defOp->getOperand(i));
+ auto vecType = VectorType::get(numElements, operand.getType());
Value vector =
rewriter.create<vector::FromElementsOp>(loc, vecType, args);
mapping.map(operand, vector);
}
+ };
+
+ auto handleNonVectorOutputs = [&](Value newResult) {
+ for (auto [i, result] : llvm::enumerate(node->ops)) {
+ for (OpOperand &use : result->getUses()) {
+ Operation *useOwner = use.getOwner();
+ if (getNodeForOp(useOwner))
+ continue;
+
+ Value elem = rewriter.create<vector::ExtractOp>(loc, newResult, i);
+ use.set(elem);
+ }
+ }
+ };
+ if (auto load = dyn_cast<memref::LoadOp>(op)) {
+ auto vecType =
+ VectorType::get(numElements, load.getMemRefType().getElementType());
+ Value result = rewriter.create<vector::LoadOp>(
+ loc, vecType, load.getMemRef(), load.getIndices());
+ mapping.map(load.getResult(), result);
+ handleNonVectorOutputs(result);
+ } else if (auto store = dyn_cast<memref::StoreOp>(op)) {
+ handleNonVectorInputs(store.getValueToStore());
+ Value val = mapping.lookupOrDefault(store.getValueToStore());
+ rewriter.create<vector::StoreOp>(loc, val, store.getMemRef(),
+ store.getIndices());
+ } else if (isVectorizable(op)) {
+ handleNonVectorInputs(op->getOperands());
Operation *newOp = rewriter.clone(*op, mapping);
auto resVectorType =
VectorType::get(numElements, op->getResultTypes().front());
newOp->getResult(0).setType(resVectorType);
mapping.map(op->getResults(), newOp->getResults());
+ handleNonVectorOutputs(newOp->getResult(0));
} else {
op->emitError("unsupported operation");
return failure();
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 28a255f90a869..2b2b91d667e00 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt %s --slp-vectorizer | FileCheck %s
+
// CHECK-LABEL: func @read_write
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
func.func @read_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
@@ -24,6 +25,41 @@ func.func @read_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
return
}
+// CHECK-LABEL: func @read_read_add
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i32, i32, i32){
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+ // CHECK: %[[R0:.*]] = vector.extract %[[RES]][0] : i32 from vector<4xi32>
+ // CHECK: %[[R1:.*]] = vector.extract %[[RES]][1] : i32 from vector<4xi32>
+ // CHECK: %[[R2:.*]] = vector.extract %[[RES]][2] : i32 from vector<4xi32>
+ // CHECK: %[[R3:.*]] = vector.extract %[[RES]][3] : i32 from vector<4xi32>
+ // CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]] : i32, i32, i32, i32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %0 = memref.load %arg0[%c0] : memref<8xi32>
+ %1 = memref.load %arg0[%c1] : memref<8xi32>
+ %2 = memref.load %arg0[%c2] : memref<8xi32>
+ %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+ %4 = memref.load %arg1[%c0] : memref<8xi32>
+ %5 = memref.load %arg1[%c1] : memref<8xi32>
+ %6 = memref.load %arg1[%c2] : memref<8xi32>
+ %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+ %8 = arith.addi %0, %4 : i32
+ %9 = arith.addi %1, %5 : i32
+ %10 = arith.addi %2, %6 : i32
+ %11 = arith.addi %3, %7 : i32
+
+ return %8, %9, %10, %11 : i32, i32, i32, i32
+}
+
// CHECK-LABEL: func @read_read_add_write
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
>From 7b24debf6d7e8e095dc60fc5b3bbcbbce19e27dd Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 17:41:39 +0200
Subject: [PATCH 17/27] fixes
---
.../Vector/Transforms/SLPVectorizer.cpp | 56 +++++++++++++++++--
mlir/test/Dialect/Vector/slp-vectorize.mlir | 29 ++++++++++
2 files changed, 79 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index a40131a1b10ff..ab0b3f549192f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -259,9 +259,13 @@ class SLPGraph {
}
});
+ auto isGoodNode = [&](SLPGraphNode *node) {
+ return node->users.empty() && node->operands.empty();
+ };
+
IRMapping mapping;
for (auto *node : sortedNodes) {
- if (node->users.empty() && node->operands.empty())
+ if (isGoodNode(node))
continue;
int64_t numElements = node->ops.size();
@@ -326,6 +330,9 @@ class SLPGraph {
}
for (auto *node : llvm::reverse(sortedNodes)) {
+ if (isGoodNode(node))
+ continue;
+
for (Operation *op : node->ops) {
rewriter.eraseOp(op);
}
@@ -560,10 +567,47 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
auto *newNode = graph.addNode(currentOps);
graph.addEdge(node, newNode);
- for (Operation *op : currentOps) {
+ for (Operation *op : currentOps)
fingerprints.invalidate(op);
+
+ worklist.push_back(newNode);
+ };
+
+ auto processOperands = [&](SLPGraphNode *node, Value operand, int64_t index) {
+ Operation *srcOp = operand.getDefiningOp();
+ if (!srcOp)
+ return;
+
+ auto *existingNode = graph.getNodeForOp(srcOp);
+ if (existingNode) {
+ LLVM_DEBUG(llvm::dbgs()
+ << " Adding edge from " << srcOp->getName() << " to "
+ << node->ops.front()->getName() << "\n");
+ graph.addEdge(existingNode, node);
+ return;
+ }
+
+ if (!isVectorizable(srcOp))
+ return;
+
+ SmallVector<Operation *> currentOps;
+ currentOps.emplace_back(srcOp);
+ for (Operation *op : ArrayRef(node->ops).drop_front()) {
+ Operation *otherOp = op->getOperand(index).getDefiningOp();
+ if (!otherOp || !isEquivalent(otherOp, srcOp))
+ break;
+
+ currentOps.push_back(otherOp);
}
+ if (currentOps.size() == 1)
+ return;
+
+ auto *newNode = graph.addNode(currentOps);
+ graph.addEdge(newNode, node);
+ for (Operation *op : currentOps)
+ fingerprints.invalidate(op);
+
worklist.push_back(newNode);
};
@@ -574,11 +618,11 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
<< node->ops.front()->getName() << "\n");
Operation *op = node->ops.front();
- for (OpOperand &use : op->getUses()) {
+ for (OpOperand &use : op->getUses())
processUse(node, use);
- LLVM_DEBUG(llvm::dbgs() << " Processing use in operation: "
- << use.getOwner()->getName() << "\n");
- }
+
+ for (auto [i, operand] : llvm::enumerate(op->getOperands()))
+ processOperands(node, operand, i);
}
return graph;
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 2b2b91d667e00..036e1fcbed5d5 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -60,6 +60,35 @@ func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i3
return %8, %9, %10, %11 : i32, i32, i32, i32
}
+// CHECK-LABEL: func @add_write
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: i32, %[[ARG7:.*]]: i32, %[[ARG8:.*]]: memref<8xi32>)
+func.func @add_write(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32,
+ %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32,
+ %arg8: memref<8xi32>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[A:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]] : vector<4xi32>
+ // CHECK: %[[B:.*]] = vector.from_elements %[[ARG4]], %[[ARG5]], %[[ARG6]], %[[ARG7]] : vector<4xi32>
+ // CHECK: %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+ // CHECK: vector.store %[[RES]], %[[ARG8]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %8 = arith.addi %arg0, %arg4 : i32
+ %9 = arith.addi %arg1, %arg5 : i32
+ %10 = arith.addi %arg2, %arg6 : i32
+ %11 = arith.addi %arg3, %arg7 : i32
+
+ memref.store %8, %arg8[%c0] : memref<8xi32>
+ memref.store %9, %arg8[%c1] : memref<8xi32>
+ memref.store %10, %arg8[%c2] : memref<8xi32>
+ memref.store %11, %arg8[%c3] : memref<8xi32>
+
+ return
+}
+
+
// CHECK-LABEL: func @read_read_add_write
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
>From 5433fd954da2764b8ed92de5e42cac0d903f125a Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 18:07:37 +0200
Subject: [PATCH 18/27] handle size mismatch
---
.../Vector/Transforms/SLPVectorizer.cpp | 21 +++++++
mlir/test/Dialect/Vector/slp-vectorize.mlir | 62 ++++++++++++++++++-
2 files changed, 82 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index ab0b3f549192f..f54a9aba0e6c0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -302,6 +302,16 @@ class SLPGraph {
}
};
+ auto handleVecSizeMismatch = [&](Value arg) -> Value {
+ auto srcType = cast<VectorType>(arg.getType());
+ assert(srcType.getRank() == 1);
+ if (srcType.getDimSize(0) == numElements)
+ return arg;
+
+ return rewriter.create<vector::ExtractStridedSliceOp>(loc, arg, 0,
+ numElements, 1);
+ };
+
if (auto load = dyn_cast<memref::LoadOp>(op)) {
auto vecType =
VectorType::get(numElements, load.getMemRefType().getElementType());
@@ -312,6 +322,7 @@ class SLPGraph {
} else if (auto store = dyn_cast<memref::StoreOp>(op)) {
handleNonVectorInputs(store.getValueToStore());
Value val = mapping.lookupOrDefault(store.getValueToStore());
+ val = handleVecSizeMismatch(val);
rewriter.create<vector::StoreOp>(loc, val, store.getMemRef(),
store.getIndices());
} else if (isVectorizable(op)) {
@@ -319,6 +330,15 @@ class SLPGraph {
Operation *newOp = rewriter.clone(*op, mapping);
auto resVectorType =
VectorType::get(numElements, op->getResultTypes().front());
+
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(newOp);
+ for (OpOperand &operand : newOp->getOpOperands()) {
+ Value newOperand = handleVecSizeMismatch(operand.get());
+ operand.set(newOperand);
+ }
+ }
newOp->getResult(0).setType(resVectorType);
mapping.map(op->getResults(), newOp->getResults());
@@ -701,6 +721,7 @@ void SLPVectorizerPass::runOnOperation() {
LLVM_DEBUG(llvm::dbgs() << "Failed to vectorize graph\n");
return signalPassFailure();
}
+ op->dump();
});
}
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 036e1fcbed5d5..76592833a78b4 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -25,6 +25,31 @@ func.func @read_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
return
}
+
+// CHECK-LABEL: func @read_write_size_mistamtch
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_write_size_mistamtch(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[RES1:.*]] = vector.extract_strided_slice %[[RES]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+ // CHECK: vector.store %[[RES1]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %0 = memref.load %arg0[%c0] : memref<8xi32>
+ %1 = memref.load %arg0[%c1] : memref<8xi32>
+ %2 = memref.load %arg0[%c2] : memref<8xi32>
+ %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+ memref.store %0, %arg0[%c0] : memref<8xi32>
+ memref.store %1, %arg0[%c1] : memref<8xi32>
+
+ return
+}
+
+
// CHECK-LABEL: func @read_read_add
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i32, i32, i32){
@@ -60,6 +85,7 @@ func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i3
return %8, %9, %10, %11 : i32, i32, i32, i32
}
+
// CHECK-LABEL: func @add_write
// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: i32, %[[ARG7:.*]]: i32, %[[ARG8:.*]]: memref<8xi32>)
func.func @add_write(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32,
@@ -89,7 +115,6 @@ func.func @add_write(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32,
}
-
// CHECK-LABEL: func @read_read_add_write
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
func.func @read_read_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
@@ -125,3 +150,38 @@ func.func @read_read_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
return
}
+
+
+// CHECK-LABEL: func @read_read_add_write_size_mismatch
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_size_mismatch(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[A1:.*]] = vector.extract_strided_slice %[[A]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+ // CHECK: %[[B1:.*]] = vector.extract_strided_slice %[[B]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+ // CHECK: %[[RES:.*]] = arith.addi %[[A1]], %[[B1]] : vector<2xi32>
+ // CHECK: vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %0 = memref.load %arg0[%c0] : memref<8xi32>
+ %1 = memref.load %arg0[%c1] : memref<8xi32>
+ %2 = memref.load %arg0[%c2] : memref<8xi32>
+ %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+ %4 = memref.load %arg1[%c0] : memref<8xi32>
+ %5 = memref.load %arg1[%c1] : memref<8xi32>
+ %6 = memref.load %arg1[%c2] : memref<8xi32>
+ %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+ %8 = arith.addi %0, %4 : i32
+ %9 = arith.addi %1, %5 : i32
+
+ memref.store %8, %arg0[%c0] : memref<8xi32>
+ memref.store %9, %arg0[%c1] : memref<8xi32>
+
+ return
+}
>From 2f02d807ac75c95770d1ff72a082c69024616f2c Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 18:53:54 +0200
Subject: [PATCH 19/27] adjacent indices
---
.../Vector/Transforms/SLPVectorizer.cpp | 98 ++++++++++---------
mlir/test/Dialect/Vector/slp-vectorize.mlir | 25 +++++
2 files changed, 79 insertions(+), 44 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index f54a9aba0e6c0..cc252a0e32c06 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -52,16 +52,43 @@ struct MemoryOpGroup {
bool empty() const { return ops.empty(); }
};
-// Helper function to extract base and index from a memory operation
-std::optional<std::pair<Value, int64_t>> getBaseAndIndex(Operation *op) {
- if (auto loadOp = dyn_cast<memref::LoadOp>(op)) {
- if (auto value = getConstantIntValue(loadOp.getIndices().front()))
- return std::make_pair(loadOp.getMemRef(), *value);
- } else if (auto storeOp = dyn_cast<memref::StoreOp>(op)) {
- if (auto value = getConstantIntValue(storeOp.getIndices().front()))
- return std::make_pair(storeOp.getMemRef(), *value);
+static ValueRange getIndices(Operation *op) {
+ if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+ return loadOp.getIndices();
+ if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+ return storeOp.getIndices();
+ return {};
+}
+
+static Type getElementType(Operation *op) {
+ if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+ return loadOp.getResult().getType();
+ if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+ return storeOp.getValueToStore().getType();
+ return {};
+}
+
+static bool isAdjacentIndices(Value idx1, Value idx2) {
+ if (auto c1 = getConstantIntValue(idx1)) {
+ if (auto c2 = getConstantIntValue(idx2))
+ return *c1 + 1 == *c2;
}
- return std::nullopt;
+ return false;
+}
+
+static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2) {
+ if (idx1.empty() || idx1.size() != idx2.size())
+ return false;
+
+ if (idx1.drop_back() != idx2.drop_back())
+ return false;
+
+ return isAdjacentIndices(idx1.back(), idx2.back());
+}
+
+static bool isAdjacentIndices(Operation *op1, Operation *op2) {
+ return getElementType(op1) == getElementType(op2) &&
+ isAdjacentIndices(getIndices(op1), getIndices(op2));
}
// Extract contiguous groups from a MemoryOpGroup
@@ -71,64 +98,48 @@ extractContiguousGroups(const MemoryOpGroup &group) {
if (group.ops.empty())
return result;
- // Keep track of which operations we've processed
- DenseSet<Operation *> processedOps;
+ llvm::SmallDenseSet<Operation *> processedOps;
- // Process each operation
for (Operation *op : group.ops) {
- // Skip if we've already processed this operation
if (processedOps.contains(op))
continue;
- // Get base and index of current operation
- auto baseAndIndex = getBaseAndIndex(op);
- if (!baseAndIndex)
- continue;
-
- auto [base, index] = *baseAndIndex;
-
// Start a new group with this operation
result.emplace_back(group.type);
MemoryOpGroup &currentGroup = result.back();
- currentGroup.ops.push_back(op);
+ auto &currentOps = currentGroup.ops;
+ currentOps.push_back(op);
processedOps.insert(op);
- LLVM_DEBUG(llvm::dbgs() << "Starting new group at base " << base
- << " index " << index << "\n");
-
- // Try to find operations with adjacent indices
bool foundMore;
do {
foundMore = false;
- // Look for operations with index+1
for (Operation *otherOp : group.ops) {
if (processedOps.contains(otherOp))
continue;
- auto otherBaseAndIndex = getBaseAndIndex(otherOp);
- if (!otherBaseAndIndex)
- continue;
-
- auto [otherBase, otherIndex] = *otherBaseAndIndex;
-
- // Check if this operation has the same base and adjacent index
- if (otherBase == base && otherIndex == currentGroup.ops.size()) {
- currentGroup.ops.push_back(otherOp);
+ Operation *firstOp = currentOps.front();
+ Operation *lastOp = currentOps.back();
+ if (isAdjacentIndices(otherOp, firstOp)) {
+ currentOps.insert(currentOps.begin(), otherOp);
+ processedOps.insert(otherOp);
+ foundMore = true;
+ } else if (isAdjacentIndices(lastOp, otherOp)) {
+ currentOps.push_back(otherOp);
processedOps.insert(otherOp);
foundMore = true;
- LLVM_DEBUG(llvm::dbgs()
- << "Added operation with index " << otherIndex << "\n");
- break;
}
}
} while (foundMore);
- }
- // Remove empty groups
- result.erase(std::remove_if(result.begin(), result.end(),
- [](const MemoryOpGroup &g) { return g.empty(); }),
- result.end());
+ if (currentOps.size() <= 1) {
+ result.pop_back();
+ continue;
+ }
+ LLVM_DEBUG(llvm::dbgs() << "Extracted contiguous group with "
+ << currentGroup.ops.size() << " operations\n");
+ }
return result;
}
@@ -721,7 +732,6 @@ void SLPVectorizerPass::runOnOperation() {
LLVM_DEBUG(llvm::dbgs() << "Failed to vectorize graph\n");
return signalPassFailure();
}
- op->dump();
});
}
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 76592833a78b4..6be405ad078b9 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -50,6 +50,31 @@ func.func @read_write_size_mistamtch(%arg0: memref<8xi32>, %arg1: memref<8xi32>)
}
+// CHECK-LABEL: func @read_write_interleaved
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_write_interleaved(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %3 = memref.load %arg0[%c3] : memref<8xi32>
+ %0 = memref.load %arg0[%c0] : memref<8xi32>
+ %2 = memref.load %arg0[%c2] : memref<8xi32>
+ %1 = memref.load %arg0[%c1] : memref<8xi32>
+
+ memref.store %1, %arg0[%c1] : memref<8xi32>
+ memref.store %0, %arg0[%c0] : memref<8xi32>
+ memref.store %3, %arg0[%c3] : memref<8xi32>
+ memref.store %2, %arg0[%c2] : memref<8xi32>
+
+ return
+}
+
+
// CHECK-LABEL: func @read_read_add
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i32, i32, i32){
>From 019f5614f606a9f5d031367de77de692819b0efc Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 18:57:52 +0200
Subject: [PATCH 20/27] test
---
mlir/test/Dialect/Vector/slp-vectorize.mlir | 38 +++++++++++++++++++++
1 file changed, 38 insertions(+)
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 6be405ad078b9..9c5005f807c71 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -210,3 +210,41 @@ func.func @read_read_add_write_size_mismatch(%arg0: memref<8xi32>, %arg1: memref
return
}
+
+
+// CHECK-LABEL: func @read_read_add_write_interleaved
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_interleaved(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+ // CHECK: vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %3 = memref.load %arg0[%c3] : memref<8xi32>
+ %7 = memref.load %arg1[%c3] : memref<8xi32>
+ %11 = arith.addi %3, %7 : i32
+
+ %0 = memref.load %arg0[%c0] : memref<8xi32>
+ %4 = memref.load %arg1[%c0] : memref<8xi32>
+ %8 = arith.addi %0, %4 : i32
+
+ %2 = memref.load %arg0[%c2] : memref<8xi32>
+ %6 = memref.load %arg1[%c2] : memref<8xi32>
+ %10 = arith.addi %2, %6 : i32
+
+ %1 = memref.load %arg0[%c1] : memref<8xi32>
+ %5 = memref.load %arg1[%c1] : memref<8xi32>
+ %9 = arith.addi %1, %5 : i32
+
+ memref.store %8, %arg0[%c0] : memref<8xi32>
+ memref.store %11, %arg0[%c3] : memref<8xi32>
+ memref.store %10, %arg0[%c2] : memref<8xi32>
+ memref.store %9, %arg0[%c1] : memref<8xi32>
+
+ return
+}
>From fc5d42c732bc9632123925e8d152d0fff23e8813 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 19:07:26 +0200
Subject: [PATCH 21/27] fixes and test
---
.../Vector/Transforms/SLPVectorizer.cpp | 11 +++-
mlir/test/Dialect/Vector/slp-vectorize.mlir | 54 +++++++++++++++++++
2 files changed, 64 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index cc252a0e32c06..3ff46093d9fbe 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -52,6 +52,14 @@ struct MemoryOpGroup {
bool empty() const { return ops.empty(); }
};
+static Value getBase(Operation *op) {
+ if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+ return loadOp.getMemRef();
+ if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+ return storeOp.getMemRef();
+ return {};
+}
+
static ValueRange getIndices(Operation *op) {
if (auto loadOp = dyn_cast<memref::LoadOp>(op))
return loadOp.getIndices();
@@ -87,7 +95,8 @@ static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2) {
}
static bool isAdjacentIndices(Operation *op1, Operation *op2) {
- return getElementType(op1) == getElementType(op2) &&
+ return getBase(op1) == getBase(op2) &&
+ getElementType(op1) == getElementType(op2) &&
isAdjacentIndices(getIndices(op1), getIndices(op2));
}
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 9c5005f807c71..820fbf2d260cd 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -248,3 +248,57 @@ func.func @read_read_add_write_interleaved(%arg0: memref<8xi32>, %arg1: memref<8
return
}
+
+
+// CHECK-LABEL: func @read_read_add_add_write
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>
+// CHECK-SAME: , %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
+func.func @read_read_add_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>,
+ %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[ADD1:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+ // CHECK: %[[C:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]] : vector<4xi32>
+ // CHECK: %[[ADD2:.*]] = arith.addi %[[A]], %[[C]] : vector<4xi32>
+ // CHECK: vector.store %[[ADD1]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: vector.store %[[ADD2]], %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %0 = memref.load %arg0[%c0] : memref<8xi32>
+ %1 = memref.load %arg0[%c1] : memref<8xi32>
+ %2 = memref.load %arg0[%c2] : memref<8xi32>
+ %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+ %4 = memref.load %arg1[%c0] : memref<8xi32>
+ %5 = memref.load %arg1[%c1] : memref<8xi32>
+ %6 = memref.load %arg1[%c2] : memref<8xi32>
+ %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+ %8 = arith.addi %0, %4 : i32
+ %12 = arith.addi %0, %arg2 : i32
+
+ %13 = arith.addi %1, %arg3 : i32
+ %9 = arith.addi %1, %5 : i32
+
+ %10 = arith.addi %2, %6 : i32
+ %14 = arith.addi %2, %arg4 : i32
+
+ %15 = arith.addi %3, %arg5 : i32
+ %11 = arith.addi %3, %7 : i32
+
+ memref.store %8, %arg0[%c0] : memref<8xi32>
+ memref.store %9, %arg0[%c1] : memref<8xi32>
+ memref.store %10, %arg0[%c2] : memref<8xi32>
+ memref.store %11, %arg0[%c3] : memref<8xi32>
+
+ memref.store %12, %arg1[%c0] : memref<8xi32>
+ memref.store %13, %arg1[%c1] : memref<8xi32>
+ memref.store %14, %arg1[%c2] : memref<8xi32>
+ memref.store %15, %arg1[%c3] : memref<8xi32>
+
+ return
+}
>From e7e1172787b2697e7ad9860849b48a229d9e73bd Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 19:22:03 +0200
Subject: [PATCH 22/27] better side effects handling
---
.../Vector/Transforms/SLPVectorizer.cpp | 94 +++++++++++--------
1 file changed, 55 insertions(+), 39 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 3ff46093d9fbe..6cb6faa486702 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -52,6 +52,61 @@ struct MemoryOpGroup {
bool empty() const { return ops.empty(); }
};
+static bool isReadOp(Operation *op) {
+ auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
+ if (!effectInterface)
+ return true;
+
+ return effectInterface.hasEffect<MemoryEffects::Read>();
+}
+
+static bool isWriteOp(Operation *op) {
+ auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
+ if (!effectInterface)
+ return true;
+
+ return effectInterface.hasEffect<MemoryEffects::Write>();
+}
+
+/// Collect all memory operations in the block into groups.
+/// Each group contains either all loads or all stores, uninterrupted by
+/// operations of the other type.
+static SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block) {
+ SmallVector<MemoryOpGroup> groups;
+ MemoryOpGroup *currentGroup = nullptr;
+
+ for (Operation &op : block) {
+ if (currentGroup) {
+ if (currentGroup->isLoadGroup() && isWriteOp(&op)) {
+ currentGroup = nullptr;
+ } else if (currentGroup->isStoreGroup() && isReadOp(&op)) {
+ currentGroup = nullptr;
+ }
+ }
+
+ if (!isa<memref::LoadOp, memref::StoreOp>(op))
+ continue;
+
+ bool isLoad = isReadOp(&op);
+ MemoryOpGroup::Type type =
+ isLoad ? MemoryOpGroup::Type::Load : MemoryOpGroup::Type::Store;
+
+ if (!currentGroup) {
+ groups.emplace_back(type);
+ currentGroup = &groups.back();
+ }
+
+ currentGroup->ops.push_back(&op);
+ }
+
+ // Remove empty groups
+ groups.erase(std::remove_if(groups.begin(), groups.end(),
+ [](const MemoryOpGroup &g) { return g.empty(); }),
+ groups.end());
+
+ return groups;
+}
+
static Value getBase(Operation *op) {
if (auto loadOp = dyn_cast<memref::LoadOp>(op))
return loadOp.getMemRef();
@@ -449,12 +504,6 @@ class SLPGraph {
struct SLPVectorizerPass
: public mlir::vector::impl::SLPVectorizerBase<SLPVectorizerPass> {
void runOnOperation() override;
-
-private:
- /// Collect all memory operations in the block into groups.
- /// Each group contains either all loads or all stores, uninterrupted by
- /// operations of the other type.
- SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block);
};
using Fingerprint = std::array<uint8_t, 20>;
@@ -668,39 +717,6 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
return graph;
}
-SmallVector<MemoryOpGroup>
-SLPVectorizerPass::collectMemoryOpGroups(Block &block) {
- SmallVector<MemoryOpGroup> groups;
- MemoryOpGroup *currentGroup = nullptr;
-
- for (Operation &op : block) {
- // Skip non-memory operations
- if (!isa<memref::LoadOp, memref::StoreOp>(op))
- continue;
-
- bool isLoad = isa<memref::LoadOp>(op);
- MemoryOpGroup::Type type =
- isLoad ? MemoryOpGroup::Type::Load : MemoryOpGroup::Type::Store;
-
- // Start a new group if:
- // 1. We don't have a current group, or
- // 2. The current operation is a different type than the current group
- if (!currentGroup || currentGroup->type != type) {
- groups.emplace_back(type);
- currentGroup = &groups.back();
- }
-
- currentGroup->ops.push_back(&op);
- }
-
- // Remove empty groups
- groups.erase(std::remove_if(groups.begin(), groups.end(),
- [](const MemoryOpGroup &g) { return g.empty(); }),
- groups.end());
-
- return groups;
-}
-
void SLPVectorizerPass::runOnOperation() {
Operation *op = getOperation();
>From ae187a0dada44517f56efe5eaa4ce981f8899dfc Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 19:44:47 +0200
Subject: [PATCH 23/27] cleanup
---
.../mlir/Dialect/Vector/Transforms/Passes.h | 3 --
.../mlir/Dialect/Vector/Transforms/Passes.td | 14 ++++--
.../Vector/Transforms/SLPVectorizer.cpp | 49 ++++++++++++-------
mlir/test/Dialect/Vector/slp-vectorize.mlir | 2 +-
4 files changed, 43 insertions(+), 25 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index 43112f084dc60..5667f4fa95ace 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -25,9 +25,6 @@ std::unique_ptr<Pass> createLowerVectorMultiReductionPass(
VectorMultiReductionLowering option =
VectorMultiReductionLowering::InnerParallel);
-/// Creates a pass that implements the SLP vectorizer.
-std::unique_ptr<Pass> createSLPVectorizerPass();
-
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 94ccd61cb5170..d5c31c9f78409 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -34,15 +34,21 @@ def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::Func
];
}
-def SLPVectorizer : Pass<"slp-vectorizer", "ModuleOp"> {
+def GreedySLPVectorizer : Pass<"greedy-slp-vectorizer"> {
let summary = "SLP Vectorizer Pass";
let description = [{
This pass implements the SLP (Superword Level Parallelism) vectorizer.
It detects consecutive operations that can be put together into vector
- operations. The pass works bottom-up, across basic blocks, in search of
- scalars to combine.
+ operations. The pass works bi-directionally, starting from reads or stores,
+ in search of scalars to combine.
+
+ This is a greedy vectorizer: it doesn't have a cost model (yet) and tries to
+ create vector ops whenever there are at least 2 potential ops.
+
+ It doesn't check whether the target actually supports the resulting vectors
+ either; the user will need a follow-up pass that splits large and/or
+ unaligned vectors into sizes actually supported by the target.
}];
- let constructor = "mlir::vector::createSLPVectorizerPass()";
let dependentDialects = ["mlir::vector::VectorDialect"];
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 6cb6faa486702..d7c2dc3845cac 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -27,7 +27,7 @@
namespace mlir {
namespace vector {
-#define GEN_PASS_DEF_SLPVECTORIZER
+#define GEN_PASS_DEF_GREEDYSLPVECTORIZER
#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
} // namespace vector
} // namespace mlir
@@ -115,6 +115,19 @@ static Value getBase(Operation *op) {
return {};
}
+static bool isContiguousLastDim(Value val) {
+ auto memrefType = dyn_cast<MemRefType>(val.getType());
+ if (!memrefType)
+ return false;
+
+ int64_t offset;
+ SmallVector<int64_t> strides;
+ if (failed(memrefType.getStridesAndOffset(strides, offset)))
+ return false;
+
+ return !strides.empty() && strides.back() == 1;
+}
+
static ValueRange getIndices(Operation *op) {
if (auto loadOp = dyn_cast<memref::LoadOp>(op))
return loadOp.getIndices();
@@ -150,8 +163,15 @@ static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2) {
}
static bool isAdjacentIndices(Operation *op1, Operation *op2) {
- return getBase(op1) == getBase(op2) &&
- getElementType(op1) == getElementType(op2) &&
+ Value base1 = getBase(op1);
+ Value base2 = getBase(op2);
+ if (base1 != base2)
+ return false;
+
+ if (!isContiguousLastDim(base1))
+ return false;
+
+ return getElementType(op1) == getElementType(op2) &&
isAdjacentIndices(getIndices(op1), getIndices(op2));
}
@@ -498,11 +518,9 @@ class SLPGraph {
llvm::SmallDenseMap<Operation *, SLPGraphNode *> opToNode;
};
-/// This pass implements the SLP vectorizer. It detects consecutive operations
-/// that can be put together into vector operations. The pass works bottom-up,
-/// across basic blocks, in search of scalars to combine.
-struct SLPVectorizerPass
- : public mlir::vector::impl::SLPVectorizerBase<SLPVectorizerPass> {
+struct GreedySLPVectorizerPass
+ : public mlir::vector::impl::GreedySLPVectorizerBase<
+ GreedySLPVectorizerPass> {
void runOnOperation() override;
};
@@ -717,11 +735,11 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
return graph;
}
-void SLPVectorizerPass::runOnOperation() {
+void GreedySLPVectorizerPass::runOnOperation() {
Operation *op = getOperation();
// Walk all blocks recursively
- op->walk([&](Block *block) {
+ op->walk([&](Block *block) -> WalkResult {
LLVM_DEBUG(llvm::dbgs() << "Processing block in operation: "
<< block->getParentOp()->getName() << "\n");
@@ -747,21 +765,18 @@ void SLPVectorizerPass::runOnOperation() {
// Build the SLP graph from root groups
SLPGraph graph = buildSLPGraph(rootGroups);
-
- // Print the graph structure
LLVM_DEBUG(graph.print());
// Vectorize the graph
IRRewriter rewriter(&getContext());
if (failed(graph.vectorize(rewriter))) {
LLVM_DEBUG(llvm::dbgs() << "Failed to vectorize graph\n");
- return signalPassFailure();
+ signalPassFailure();
+ return WalkResult::interrupt();
}
+
+ return WalkResult::advance();
});
}
} // namespace
-
-std::unique_ptr<Pass> mlir::vector::createSLPVectorizerPass() {
- return std::make_unique<SLPVectorizerPass>();
-}
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 820fbf2d260cd..2e9298d11ed05 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --slp-vectorizer | FileCheck %s
+// RUN: mlir-opt %s --greedy-slp-vectorizer | FileCheck %s
// CHECK-LABEL: func @read_write
>From de5e898a81c1363acef5f5b00b9d9c254fa2554b Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 20:10:02 +0200
Subject: [PATCH 24/27] cleanup
---
.../Vector/Transforms/SLPVectorizer.cpp | 80 +++++++++++++------
1 file changed, 57 insertions(+), 23 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index d7c2dc3845cac..24059ec355b30 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -49,7 +49,6 @@ struct MemoryOpGroup {
bool isStoreGroup() const { return type == Type::Store; }
size_t size() const { return ops.size(); }
- bool empty() const { return ops.empty(); }
};
static bool isReadOp(Operation *op) {
@@ -99,11 +98,6 @@ static SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block) {
currentGroup->ops.push_back(&op);
}
- // Remove empty groups
- groups.erase(std::remove_if(groups.begin(), groups.end(),
- [](const MemoryOpGroup &g) { return g.empty(); }),
- groups.end());
-
return groups;
}
@@ -144,14 +138,19 @@ static Type getElementType(Operation *op) {
return {};
}
+/// Check if two indices are consecutive, i.e. the second is the first plus 1.
static bool isAdjacentIndices(Value idx1, Value idx2) {
if (auto c1 = getConstantIntValue(idx1)) {
if (auto c2 = getConstantIntValue(idx2))
return *c1 + 1 == *c2;
}
+
+ // TODO: Check arith.add, affine.apply, etc
return false;
}
+/// Check if two ranges of indices are consecutive, i.e. the fastest index
+/// differs by 1 and all other indices are the same.
static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2) {
if (idx1.empty() || idx1.size() != idx2.size())
return false;
@@ -162,7 +161,10 @@ static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2) {
return isAdjacentIndices(idx1.back(), idx2.back());
}
-static bool isAdjacentIndices(Operation *op1, Operation *op2) {
+/// Check if two operations are adjacent and can be combined into a vector op.
+/// This is done by checking if the base memrefs are the same, the last
+/// dimension is contiguous, and the element types and indices are compatible.
+static bool isAdjacentOps(Operation *op1, Operation *op2) {
Value base1 = getBase(op1);
Value base2 = getBase(op2);
if (base1 != base2)
@@ -195,6 +197,8 @@ extractContiguousGroups(const MemoryOpGroup &group) {
currentOps.push_back(op);
processedOps.insert(op);
+ // Keep adding ops to the beginning or end of the current group until no
+ // more ops can be added.
bool foundMore;
do {
foundMore = false;
@@ -204,11 +208,11 @@ extractContiguousGroups(const MemoryOpGroup &group) {
Operation *firstOp = currentOps.front();
Operation *lastOp = currentOps.back();
- if (isAdjacentIndices(otherOp, firstOp)) {
+ if (isAdjacentOps(otherOp, firstOp)) {
currentOps.insert(currentOps.begin(), otherOp);
processedOps.insert(otherOp);
foundMore = true;
- } else if (isAdjacentIndices(lastOp, otherOp)) {
+ } else if (isAdjacentOps(lastOp, otherOp)) {
currentOps.push_back(otherOp);
processedOps.insert(otherOp);
foundMore = true;
@@ -222,7 +226,7 @@ extractContiguousGroups(const MemoryOpGroup &group) {
}
LLVM_DEBUG(llvm::dbgs() << "Extracted contiguous group with "
- << currentGroup.ops.size() << " operations\n");
+ << currentGroup.size() << " operations\n");
}
return result;
}
@@ -241,6 +245,8 @@ struct SLPGraphNode {
SLPGraphNode() = default;
SLPGraphNode(ArrayRef<Operation *> operations)
: ops(operations.begin(), operations.end()) {}
+
+ size_t size() const { return ops.size(); }
};
/// A graph of vectorizable operations
@@ -349,7 +355,7 @@ class SLPGraph {
LLVM_DEBUG({
llvm::dbgs() << "Topologically sorted nodes:\n";
for (auto *node : sortedNodes) {
- llvm::dbgs() << " Node with " << node->ops.size()
+ llvm::dbgs() << " Node with " << node->size()
<< " operations: " << node->ops.front()->getName() << "\n";
}
});
@@ -363,7 +369,7 @@ class SLPGraph {
if (isGoodNode(node))
continue;
- int64_t numElements = node->ops.size();
+ int64_t numElements = node->size();
Operation *op = node->ops.front();
rewriter.setInsertionPoint(op);
Location loc = op->getLoc();
@@ -467,15 +473,15 @@ class SLPGraph {
if (!node->isRoot)
continue;
llvm::dbgs() << " "
- << (isa<memref::LoadOp>(node->ops[0]) ? "LOAD" : "STORE")
- << " group with " << node->ops.size() << " operations:\n";
+ << (isa<memref::LoadOp>(node->ops.front()) ? "LOAD"
+ : "STORE")
+ << " group with " << node->size() << " operations:\n";
for (auto *op : node->ops) {
llvm::dbgs() << " " << *op << "\n";
}
llvm::dbgs() << " Users: ";
for (auto *user : node->users) {
- llvm::dbgs() << "\n Group with " << user->ops.size()
- << " operations:";
+ llvm::dbgs() << "\n Group with " << user->size() << " operations:";
for (auto *op : user->ops) {
llvm::dbgs() << "\n " << *op;
}
@@ -488,13 +494,13 @@ class SLPGraph {
for (const auto &node : nodes) {
if (node->isRoot)
continue;
- llvm::dbgs() << " Group with " << node->ops.size() << " operations:\n";
+ llvm::dbgs() << " Group with " << node->size() << " operations:\n";
for (auto *op : node->ops) {
llvm::dbgs() << " " << *op << "\n";
}
llvm::dbgs() << " Operands: ";
for (auto *operand : node->operands) {
- llvm::dbgs() << "\n Group with " << operand->ops.size()
+ llvm::dbgs() << "\n Group with " << operand->size()
<< " operations:";
for (auto *op : operand->ops) {
llvm::dbgs() << "\n " << *op;
@@ -502,8 +508,7 @@ class SLPGraph {
}
llvm::dbgs() << "\n Users: ";
for (auto *user : node->users) {
- llvm::dbgs() << "\n Group with " << user->ops.size()
- << " operations:";
+ llvm::dbgs() << "\n Group with " << user->size() << " operations:";
for (auto *op : user->ops) {
llvm::dbgs() << "\n " << *op;
}
@@ -518,6 +523,28 @@ class SLPGraph {
llvm::SmallDenseMap<Operation *, SLPGraphNode *> opToNode;
};
+/// This pass implements the greedy SLP vectorizer. It detects consecutive
+/// operations that can be put together into vector operations. The pass works
+/// bi-directionally, starting from reads or stores, in search of scalars to
+/// combine.
+///
+/// The pass is split into multiple steps:
+/// 1. Collect memory operation groups within the same block.
+/// A group is either multiple loads uninterrupted by stores or multiple
+/// stores uninterrupted by loads.
+///
+/// 2. Extract contiguous groups from memory operation groups, based on the
+/// ops base memrefs, load/store element types, and indices.
+///
+/// 3. Build SLP graph from contiguous groups. This is done by going both
+/// top-down and bottom-up through uses/operands respectively, starting from
+/// contiguous memory operation groups.
+///
+/// 4. Vectorize SLP graph. This is done by topological sort of the graph and
+/// vectorizing each node in the order of the sort.
+///
+/// Vectorization is done by cloning the operations and mapping the operands and
+/// results.
struct GreedySLPVectorizerPass
: public mlir::vector::impl::GreedySLPVectorizerBase<
GreedySLPVectorizerPass> {
@@ -532,6 +559,10 @@ static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
}
+/// The SLP vectorizer is bi-directional, so when we go top-down we can have
+/// multiple users with the same immediate op type. This class tries to compute
+/// a fingerprint for such ops based on the entire ops graph to maximize further
+/// scalar ops merging.
struct OperationsFingerprint {
OperationsFingerprint(const SLPGraph &graph) : graph(graph) {}
@@ -606,7 +637,8 @@ static bool isEquivalent(Operation *op1, Operation *op2) {
return true;
}
-/// Build the SLP graph starting from memory operation groups
+/// Build the SLP graph starting from memory operation groups and going both
+/// top-down and bottom-up through uses/operands respectively.
static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
if (rootGroups.empty())
return SLPGraph();
@@ -623,7 +655,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
worklist.push_back(node);
LLVM_DEBUG({
- llvm::dbgs() << "Created root group node with " << node->ops.size()
+ llvm::dbgs() << "Created root group node with " << node->size()
<< " operations of type "
<< (group.isLoadGroup() ? "Load" : "Store") << "\n";
});
@@ -631,6 +663,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
OperationsFingerprint fingerprints(graph);
+ // Process node uses, going top-down.
auto processUse = [&](SLPGraphNode *node, OpOperand &use) {
Operation *user = use.getOwner();
auto *existingNode = graph.getNodeForOp(user);
@@ -680,6 +713,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
worklist.push_back(newNode);
};
+ // Process node operands, going bottom-up.
auto processOperands = [&](SLPGraphNode *node, Value operand, int64_t index) {
Operation *srcOp = operand.getDefiningOp();
if (!srcOp)
@@ -720,7 +754,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
while (!worklist.empty()) {
SLPGraphNode *node = worklist.pop_back_val();
- LLVM_DEBUG(llvm::dbgs() << "Processing node with " << node->ops.size()
+ LLVM_DEBUG(llvm::dbgs() << "Processing node with " << node->size()
<< " operations, first op: "
<< node->ops.front()->getName() << "\n");
>From 910f7a094c061bd6f1152c194f71462a7220ec0f Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 22:08:35 +0200
Subject: [PATCH 25/27] check arith.add indices
---
.../Vector/Transforms/SLPVectorizer.cpp | 13 ++++-
mlir/test/Dialect/Vector/slp-vectorize.mlir | 54 +++++++++++++++++++
2 files changed, 66 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 24059ec355b30..aa2f3108712f1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -145,7 +145,18 @@ static bool isAdjacentIndices(Value idx1, Value idx2) {
return *c1 + 1 == *c2;
}
- // TODO: Check arith.add, affine.apply, etc
+ if (auto addOp2 = idx2.getDefiningOp<arith::AddIOp>()) {
+ if (addOp2.getLhs() == idx1 && getConstantIntValue(addOp2.getRhs()) == 1)
+ return true;
+
+ if (auto addOp1 = idx1.getDefiningOp<arith::AddIOp>()) {
+ if (addOp1.getLhs() == addOp2.getLhs() &&
+ isAdjacentIndices(addOp1.getRhs(), addOp2.getRhs()))
+ return true;
+ }
+ }
+
+ // TODO: affine.apply, etc
return false;
}
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 2e9298d11ed05..edb722472995d 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -75,6 +75,60 @@ func.func @read_write_interleaved(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
}
+// CHECK-LABEL: func @read_write_add_index
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>, %[[ARG2:.*]]: index)
+func.func @read_write_add_index(%arg0: memref<8xi32>, %arg1: memref<8xi32>, %arg2: index) {
+ // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG2]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG2]]] : memref<8xi32>, vector<4xi32>
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %ind1 = arith.addi %arg2, %c1 : index
+ %ind2 = arith.addi %arg2, %c2 : index
+ %ind3 = arith.addi %arg2, %c3 : index
+
+ %0 = memref.load %arg0[%arg2] : memref<8xi32>
+ %1 = memref.load %arg0[%ind1] : memref<8xi32>
+ %2 = memref.load %arg0[%ind2] : memref<8xi32>
+ %3 = memref.load %arg0[%ind3] : memref<8xi32>
+
+ memref.store %0, %arg0[%arg2] : memref<8xi32>
+ memref.store %1, %arg0[%ind1] : memref<8xi32>
+ memref.store %2, %arg0[%ind2] : memref<8xi32>
+ memref.store %3, %arg0[%ind3] : memref<8xi32>
+
+ return
+}
+
+
+// CHECK-LABEL: func @read_write_add_index_interleaved
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>, %[[ARG2:.*]]: index)
+func.func @read_write_add_index_interleaved(%arg0: memref<8xi32>, %arg1: memref<8xi32>, %arg2: index) {
+ // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG2]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG2]]] : memref<8xi32>, vector<4xi32>
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %ind1 = arith.addi %arg2, %c1 : index
+ %ind2 = arith.addi %arg2, %c2 : index
+ %ind3 = arith.addi %arg2, %c3 : index
+
+ %0 = memref.load %arg0[%arg2] : memref<8xi32>
+ %1 = memref.load %arg0[%ind1] : memref<8xi32>
+ %3 = memref.load %arg0[%ind3] : memref<8xi32>
+ %2 = memref.load %arg0[%ind2] : memref<8xi32>
+
+ memref.store %3, %arg0[%ind3] : memref<8xi32>
+ memref.store %0, %arg0[%arg2] : memref<8xi32>
+ memref.store %1, %arg0[%ind1] : memref<8xi32>
+ memref.store %2, %arg0[%ind2] : memref<8xi32>
+
+ return
+}
+
+
// CHECK-LABEL: func @read_read_add
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i32, i32, i32){
>From 0db4c55a99cc7328a8e2b0233f545f9325eb393d Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 23:41:34 +0200
Subject: [PATCH 26/27] fix vecor sizes
---
.../Vector/Transforms/SLPVectorizer.cpp | 24 ++++++----
mlir/test/Dialect/Vector/slp-vectorize.mlir | 47 +++++++++++++++++++
2 files changed, 61 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index aa2f3108712f1..dfd4747f615ee 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -12,14 +12,11 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
-#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/SHA1.h"
@@ -371,15 +368,24 @@ class SLPGraph {
}
});
- auto isGoodNode = [&](SLPGraphNode *node) {
+ auto isBadNode = [&](SLPGraphNode *node) {
return node->users.empty() && node->operands.empty();
};
- IRMapping mapping;
+ // Update vec sizes if inputs are smaller.
for (auto *node : sortedNodes) {
- if (isGoodNode(node))
- continue;
+ size_t size = node->size();
+ for (auto *operand : node->operands)
+ size = std::min(size, operand->size());
+
+ node->ops.resize(size);
+ }
+
+ // Remove bad nodes, i.e. nodes that have neither users nor operands.
+ llvm::erase_if(sortedNodes, isBadNode);
+ IRMapping mapping;
+ for (auto *node : sortedNodes) {
int64_t numElements = node->size();
Operation *op = node->ops.front();
rewriter.setInsertionPoint(op);
@@ -462,14 +468,12 @@ class SLPGraph {
}
for (auto *node : llvm::reverse(sortedNodes)) {
- if (isGoodNode(node))
- continue;
-
for (Operation *op : node->ops) {
rewriter.eraseOp(op);
}
}
+ LLVM_DEBUG(llvm::dbgs() << "Vectorization completed successfully\n");
return success();
}
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index edb722472995d..7ad077d8fd78c 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -356,3 +356,50 @@ func.func @read_read_add_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>,
return
}
+
+
+func.func private @use(i32)
+
+// CHECK-LABEL: func @read_read_add_write_interleaved_use
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_interleaved_use(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C3:.*]] = arith.constant 3 : index
+ // CHECK: %[[V0:.*]] = memref.load %[[ARG0]][%[[C3]]] : memref<8xi32>
+ // CHECK: %[[V1:.*]] = memref.load %[[ARG1]][%[[C3]]] : memref<8xi32>
+ // CHECK: call @use(%[[V0]]) : (i32) -> ()
+ // CHECK: %[[V2:.*]] = arith.addi %[[V0]], %[[V1]] : i32
+ // CHECK: %[[V3:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<3xi32>
+ // CHECK: %[[V4:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<3xi32>
+ // CHECK: %[[V5:.*]] = arith.addi %[[V3]], %[[V4]] : vector<3xi32>
+ // CHECK: vector.store %[[V5]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<3xi32>
+ // CHECK: memref.store %[[V2]], %[[ARG0]][%[[C3]]] : memref<8xi32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %3 = memref.load %arg0[%c3] : memref<8xi32>
+ %7 = memref.load %arg1[%c3] : memref<8xi32>
+ call @use(%3) : (i32) -> ()
+ %11 = arith.addi %3, %7 : i32
+
+ %0 = memref.load %arg0[%c0] : memref<8xi32>
+ %4 = memref.load %arg1[%c0] : memref<8xi32>
+ %8 = arith.addi %0, %4 : i32
+
+ %2 = memref.load %arg0[%c2] : memref<8xi32>
+ %6 = memref.load %arg1[%c2] : memref<8xi32>
+ %10 = arith.addi %2, %6 : i32
+
+ %1 = memref.load %arg0[%c1] : memref<8xi32>
+ %5 = memref.load %arg1[%c1] : memref<8xi32>
+ %9 = arith.addi %1, %5 : i32
+
+ memref.store %8, %arg0[%c0] : memref<8xi32>
+ memref.store %11, %arg0[%c3] : memref<8xi32>
+ memref.store %10, %arg0[%c2] : memref<8xi32>
+ memref.store %9, %arg0[%c1] : memref<8xi32>
+
+ return
+}
>From 08dcd13cb9ff7351f4eb690fc787eb71fd306369 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 23:48:08 +0200
Subject: [PATCH 27/27] fix op insertion point
---
.../Vector/Transforms/SLPVectorizer.cpp | 12 ++++-
mlir/test/Dialect/Vector/slp-vectorize.mlir | 44 +++++++++++++++++++
2 files changed, 55 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index dfd4747f615ee..ab5bfd94de49c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -255,6 +255,16 @@ struct SLPGraphNode {
: ops(operations.begin(), operations.end()) {}
size_t size() const { return ops.size(); }
+
+ Operation *getEarliestOp() const {
+ assert(!ops.empty() && "empty node");
+ Operation *ret = ops.front();
+ for (Operation *op : ArrayRef(ops).drop_front()) {
+ if (op->isBeforeInBlock(ret))
+ ret = op;
+ }
+ return ret;
+ }
};
/// A graph of vectorizable operations
@@ -388,7 +398,7 @@ class SLPGraph {
for (auto *node : sortedNodes) {
int64_t numElements = node->size();
Operation *op = node->ops.front();
- rewriter.setInsertionPoint(op);
+ rewriter.setInsertionPoint(node->getEarliestOp());
Location loc = op->getLoc();
auto handleNonVectorInputs = [&](ValueRange operands) {
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 7ad077d8fd78c..9d06a1faa07b2 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -403,3 +403,47 @@ func.func @read_read_add_write_interleaved_use(%arg0: memref<8xi32>, %arg1: memr
return
}
+
+
+// CHECK-LABEL: func @read_read_add_write_interleaved_use_add
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_interleaved_use_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[V1:.*]] = vector.extract %[[V0]][3] : i32 from vector<4xi32>
+ // CHECK: %[[V2:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[V3:.*]] = vector.extract %[[V2]][3] : i32 from vector<4xi32>
+ // CHECK: %[[V4:.*]] = arith.subi %[[V1]], %[[V3]] : i32
+ // CHECK: %[[V5:.*]] = arith.addi %[[V0]], %[[V2]] : vector<4xi32>
+ // CHECK: vector.store %[[V5]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: call @use(%[[V4]]) : (i32) -> ()
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %3 = memref.load %arg0[%c3] : memref<8xi32>
+ %7 = memref.load %arg1[%c3] : memref<8xi32>
+ %12 = arith.subi %3, %7 : i32
+ %11 = arith.addi %3, %7 : i32
+
+ %0 = memref.load %arg0[%c0] : memref<8xi32>
+ %4 = memref.load %arg1[%c0] : memref<8xi32>
+ %8 = arith.addi %0, %4 : i32
+
+ %2 = memref.load %arg0[%c2] : memref<8xi32>
+ %6 = memref.load %arg1[%c2] : memref<8xi32>
+ %10 = arith.addi %2, %6 : i32
+
+ %1 = memref.load %arg0[%c1] : memref<8xi32>
+ %5 = memref.load %arg1[%c1] : memref<8xi32>
+ %9 = arith.addi %1, %5 : i32
+
+ memref.store %8, %arg0[%c0] : memref<8xi32>
+ memref.store %11, %arg0[%c3] : memref<8xi32>
+ memref.store %10, %arg0[%c2] : memref<8xi32>
+ memref.store %9, %arg0[%c1] : memref<8xi32>
+
+ call @use(%12) : (i32) -> ()
+ return
+}
More information about the Mlir-commits
mailing list