[Mlir-commits] [mlir] [mlir][vector] MLIR SLP vectorizer (PR #140469)
Ivan Butygin
llvmlistbot at llvm.org
Sun Jun 1 03:59:33 PDT 2025
https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/140469
>From aa11ef8eda1f8392b32b453f96981fff66514ca3 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 17 May 2025 23:01:41 +0200
Subject: [PATCH 01/52] stubs
---
.../mlir/Dialect/Vector/Transforms/Passes.h | 3 +
.../mlir/Dialect/Vector/Transforms/Passes.td | 12 ++++
.../Dialect/Vector/Transforms/CMakeLists.txt | 1 +
.../Vector/Transforms/SLPVectorizer.cpp | 63 +++++++++++++++++++
mlir/test/Dialect/Vector/slp-vectorize.mlir | 34 ++++++++++
5 files changed, 113 insertions(+)
create mode 100644 mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
create mode 100644 mlir/test/Dialect/Vector/slp-vectorize.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index 5667f4fa95ace..43112f084dc60 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -25,6 +25,9 @@ std::unique_ptr<Pass> createLowerVectorMultiReductionPass(
VectorMultiReductionLowering option =
VectorMultiReductionLowering::InnerParallel);
+/// Creates a pass that implements the SLP vectorizer.
+std::unique_ptr<Pass> createSLPVectorizerPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 7436998749791..94ccd61cb5170 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -34,4 +34,16 @@ def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::Func
];
}
+def SLPVectorizer : Pass<"slp-vectorizer", "ModuleOp"> {
+ let summary = "SLP Vectorizer Pass";
+ let description = [{
+ This pass implements the SLP (Superword Level Parallelism) vectorizer.
+    It detects groups of similar, independent scalar operations that can be
+    combined into vector operations. The pass works bottom-up, across basic
+    blocks, in search of scalars to combine.
+ }];
+ let constructor = "mlir::vector::createSLPVectorizerPass()";
+ let dependentDialects = ["mlir::vector::VectorDialect"];
+}
+
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8ca5cb6c6dfab..37333b739bd86 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorStep.cpp
LowerVectorTransfer.cpp
LowerVectorTranspose.cpp
+ SLPVectorizer.cpp
SubsetOpInterfaceImpl.cpp
VectorDistribute.cpp
VectorDropLeadUnitDim.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
new file mode 100644
index 0000000000000..e9f3b12bc7461
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -0,0 +1,63 @@
+//===- SLPVectorizer.cpp - SLP Vectorizer Pass ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the SLP vectorizer pass for MLIR. The pass attempts to
+// combine similar independent operations into vector operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "slp-vectorizer"
+
+namespace mlir {
+namespace vector {
+#define GEN_PASS_DEF_SLPVECTORIZER
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+} // namespace vector
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+/// This pass implements the SLP vectorizer. It detects consecutive operations
+/// that can be put together into vector operations. The pass works bottom-up,
+/// across basic blocks, in search of scalars to combine.
+struct SLPVectorizerPass
+ : public mlir::vector::impl::SLPVectorizerBase<SLPVectorizerPass> {
+ void runOnOperation() override;
+};
+
+} // namespace
+
+void SLPVectorizerPass::runOnOperation() {
+ Operation *op = getOperation();
+ MLIRContext *context = &getContext();
+
+ // TODO: Implement SLP vectorization logic
+ // 1. Find candidate operations for vectorization
+ // 2. Build vectorization trees
+ // 3. Perform vectorization if profitable
+ // 4. Clean up scalar operations
+
+ LLVM_DEBUG(llvm::dbgs() << "Running SLP Vectorizer pass\n");
+ llvm::errs() << "Running SLP Vectorizer pass\n";
+}
+
+std::unique_ptr<Pass> mlir::vector::createSLPVectorizerPass() {
+ return std::make_unique<SLPVectorizerPass>();
+}
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
new file mode 100644
index 0000000000000..31543f3a76b2e
--- /dev/null
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s -test-slp-vectorization | FileCheck %s
+
+// CHECK-LABEL: func @basic_slp
+func.func @basic_slp(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+ // CHECK: vector.transfer_read
+ // CHECK: arith.addi
+ // CHECK: vector.transfer_write
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %0 = memref.load %arg0[%c0] : memref<8xi32>
+ %1 = memref.load %arg0[%c1] : memref<8xi32>
+ %2 = memref.load %arg0[%c2] : memref<8xi32>
+ %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+ %4 = memref.load %arg1[%c0] : memref<8xi32>
+ %5 = memref.load %arg1[%c1] : memref<8xi32>
+ %6 = memref.load %arg1[%c2] : memref<8xi32>
+ %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+ %8 = arith.addi %0, %4 : i32
+ %9 = arith.addi %1, %5 : i32
+ %10 = arith.addi %2, %6 : i32
+ %11 = arith.addi %3, %7 : i32
+
+ memref.store %8, %arg0[%c0] : memref<8xi32>
+ memref.store %9, %arg0[%c1] : memref<8xi32>
+ memref.store %10, %arg0[%c2] : memref<8xi32>
+ memref.store %11, %arg0[%c3] : memref<8xi32>
+
+ return
+}
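For reference, the end state this stub is building toward rewrites the scalar body of @basic_slp roughly as follows (a sketch only; whether the pass ultimately emits vector.load/vector.store or vector.transfer_read/vector.transfer_write is still open at this point in the series):

  func.func @basic_slp(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
    %c0 = arith.constant 0 : index
    // Four adjacent scalar loads per memref become one vector load each.
    %a = vector.load %arg0[%c0] : memref<8xi32>, vector<4xi32>
    %b = vector.load %arg1[%c0] : memref<8xi32>, vector<4xi32>
    // The four scalar arith.addi ops collapse into a single vector add.
    %sum = arith.addi %a, %b : vector<4xi32>
    vector.store %sum, %arg0[%c0] : memref<8xi32>, vector<4xi32>
    return
  }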
>From 36bd924098956fa390972489ad3dda836da73135 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 17 May 2025 23:20:34 +0200
Subject: [PATCH 02/52] something working
---
.../Vector/Transforms/SLPVectorizer.cpp | 113 +++++++++++++++++-
mlir/test/Dialect/Vector/slp-vectorize.mlir | 2 +-
2 files changed, 108 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index e9f3b12bc7461..b696f36c82eee 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
@@ -34,27 +35,127 @@ using namespace mlir;
using namespace mlir::vector;
namespace {
+/// A group of consecutive memory operations of the same type (load or store)
+/// that can potentially be vectorized together.
+struct MemoryOpGroup {
+ enum class Type { Load, Store };
+ Type type;
+ SmallVector<Operation *> ops;
+
+ MemoryOpGroup(Type t) : type(t) {}
+
+ bool isLoadGroup() const { return type == Type::Load; }
+ bool isStoreGroup() const { return type == Type::Store; }
+
+ size_t size() const { return ops.size(); }
+ bool empty() const { return ops.empty(); }
+};
+
/// This pass implements the SLP vectorizer. It detects consecutive operations
/// that can be put together into vector operations. The pass works bottom-up,
/// across basic blocks, in search of scalars to combine.
struct SLPVectorizerPass
: public mlir::vector::impl::SLPVectorizerBase<SLPVectorizerPass> {
void runOnOperation() override;
+
+private:
+ /// Collect all memory operations in the block into groups.
+ /// Each group contains either all loads or all stores, uninterrupted by
+ /// operations of the other type.
+ SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block);
};
} // namespace
+SmallVector<MemoryOpGroup>
+SLPVectorizerPass::collectMemoryOpGroups(Block &block) {
+ SmallVector<MemoryOpGroup> groups;
+ MemoryOpGroup *currentGroup = nullptr;
+
+ LLVM_DEBUG(llvm::dbgs() << "Scanning block for memory operations...\n");
+
+ for (Operation &op : block) {
+ LLVM_DEBUG(llvm::dbgs() << "Checking operation: " << op.getName() << "\n");
+
+ // Skip non-memory operations
+ if (!isa<memref::LoadOp, memref::StoreOp>(op)) {
+ LLVM_DEBUG(llvm::dbgs() << " Not a memory operation\n");
+ continue;
+ }
+
+ bool isLoad = isa<memref::LoadOp>(op);
+ MemoryOpGroup::Type type =
+ isLoad ? MemoryOpGroup::Type::Load : MemoryOpGroup::Type::Store;
+
+ LLVM_DEBUG(llvm::dbgs()
+ << " Found " << (isLoad ? "load" : "store") << " operation\n");
+
+ // Start a new group if:
+ // 1. We don't have a current group, or
+ // 2. The current operation is a different type than the current group
+ if (!currentGroup || currentGroup->type != type) {
+ LLVM_DEBUG(llvm::dbgs() << " Starting new group\n");
+ groups.emplace_back(type);
+ currentGroup = &groups.back();
+ }
+
+ currentGroup->ops.push_back(&op);
+ }
+
+ // Remove empty groups
+ groups.erase(std::remove_if(groups.begin(), groups.end(),
+ [](const MemoryOpGroup &g) { return g.empty(); }),
+ groups.end());
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "Found " << groups.size() << " memory operation groups:\n";
+ for (const auto &group : groups) {
+ llvm::dbgs() << " Group type: "
+ << (group.isLoadGroup() ? "Load" : "Store")
+ << ", size: " << group.size() << "\n";
+ }
+ });
+
+ return groups;
+}
+
void SLPVectorizerPass::runOnOperation() {
Operation *op = getOperation();
MLIRContext *context = &getContext();
- // TODO: Implement SLP vectorization logic
- // 1. Find candidate operations for vectorization
- // 2. Build vectorization trees
- // 3. Perform vectorization if profitable
- // 4. Clean up scalar operations
+ LLVM_DEBUG(llvm::dbgs() << "Running SLP Vectorizer pass on operation: "
+ << op->getName() << "\n");
+
+ // Process each function in the module
+  for (Region &region : op->getRegions()) {
+ for (Block &block : region) {
+ for (Operation &op : block) {
+ // If this is a function, process its body
+ if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Processing function: " << funcOp.getName() << "\n");
+
+ // Process each block in the function
+ for (Block &funcBlock : funcOp.getBody()) {
+ // Collect memory operation groups
+ SmallVector<MemoryOpGroup> groups =
+ collectMemoryOpGroups(funcBlock);
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "Found " << groups.size()
+ << " memory operation groups:\n";
+ for (const auto &group : groups) {
+ llvm::dbgs() << " Group type: "
+ << (group.isLoadGroup() ? "Load" : "Store")
+ << ", size: " << group.size() << "\n";
+ }
+ });
+ }
+ }
+ }
+ }
+ }
- LLVM_DEBUG(llvm::dbgs() << "Running SLP Vectorizer pass\n");
llvm::errs() << "Running SLP Vectorizer pass\n";
}
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 31543f3a76b2e..a07dd05dd16aa 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-slp-vectorization | FileCheck %s
+// RUN: mlir-opt %s --slp-vectorizer | FileCheck %s
// CHECK-LABEL: func @basic_slp
func.func @basic_slp(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
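collectMemoryOpGroups splits each block into maximal runs of memref.load or memref.store ops: a run ends only when an op of the other memory kind appears, and non-memory ops in between are simply skipped. A sketch with hypothetical IR (%m and the %cN constants are placeholders):

  %a = memref.load %m[%c0] : memref<8xi32>   // group 1 (loads)
  %b = memref.load %m[%c1] : memref<8xi32>   // group 1
  %s = arith.addi %a, %b : i32               // not a memory op, does not end the run
  memref.store %s, %m[%c0] : memref<8xi32>   // group 2 (stores)
  %c = memref.load %m[%c2] : memref<8xi32>   // group 3 (loads again)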
>From 5b220a7a43b2c1c32bed04e662946aecd5ee29b1 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 17 May 2025 23:29:20 +0200
Subject: [PATCH 03/52] block walk
---
.../Vector/Transforms/SLPVectorizer.cpp | 64 ++++---------------
1 file changed, 11 insertions(+), 53 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index b696f36c82eee..bec5f9d90b21b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -72,29 +72,19 @@ SLPVectorizerPass::collectMemoryOpGroups(Block &block) {
SmallVector<MemoryOpGroup> groups;
MemoryOpGroup *currentGroup = nullptr;
- LLVM_DEBUG(llvm::dbgs() << "Scanning block for memory operations...\n");
-
for (Operation &op : block) {
- LLVM_DEBUG(llvm::dbgs() << "Checking operation: " << op.getName() << "\n");
-
// Skip non-memory operations
- if (!isa<memref::LoadOp, memref::StoreOp>(op)) {
- LLVM_DEBUG(llvm::dbgs() << " Not a memory operation\n");
+ if (!isa<memref::LoadOp, memref::StoreOp>(op))
continue;
- }
bool isLoad = isa<memref::LoadOp>(op);
MemoryOpGroup::Type type =
isLoad ? MemoryOpGroup::Type::Load : MemoryOpGroup::Type::Store;
- LLVM_DEBUG(llvm::dbgs()
- << " Found " << (isLoad ? "load" : "store") << " operation\n");
-
// Start a new group if:
// 1. We don't have a current group, or
// 2. The current operation is a different type than the current group
if (!currentGroup || currentGroup->type != type) {
- LLVM_DEBUG(llvm::dbgs() << " Starting new group\n");
groups.emplace_back(type);
currentGroup = &groups.back();
}
@@ -107,15 +97,6 @@ SLPVectorizerPass::collectMemoryOpGroups(Block &block) {
[](const MemoryOpGroup &g) { return g.empty(); }),
groups.end());
- LLVM_DEBUG({
- llvm::dbgs() << "Found " << groups.size() << " memory operation groups:\n";
- for (const auto &group : groups) {
- llvm::dbgs() << " Group type: "
- << (group.isLoadGroup() ? "Load" : "Store")
- << ", size: " << group.size() << "\n";
- }
- });
-
return groups;
}
@@ -123,40 +104,17 @@ void SLPVectorizerPass::runOnOperation() {
Operation *op = getOperation();
MLIRContext *context = &getContext();
- LLVM_DEBUG(llvm::dbgs() << "Running SLP Vectorizer pass on operation: "
- << op->getName() << "\n");
-
- // Process each function in the module
- for (Region ®ion : op->getRegions()) {
- for (Block &block : region) {
- for (Operation &op : block) {
- // If this is a function, process its body
- if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
- LLVM_DEBUG(llvm::dbgs()
- << "Processing function: " << funcOp.getName() << "\n");
-
- // Process each block in the function
- for (Block &funcBlock : funcOp.getBody()) {
- // Collect memory operation groups
- SmallVector<MemoryOpGroup> groups =
- collectMemoryOpGroups(funcBlock);
-
- LLVM_DEBUG({
- llvm::dbgs() << "Found " << groups.size()
- << " memory operation groups:\n";
- for (const auto &group : groups) {
- llvm::dbgs() << " Group type: "
- << (group.isLoadGroup() ? "Load" : "Store")
- << ", size: " << group.size() << "\n";
- }
- });
- }
- }
- }
- }
- }
+ // Walk all blocks recursively
+ op->walk([&](Block *block) {
+ LLVM_DEBUG(llvm::dbgs() << "Processing block in operation: "
+ << block->getParentOp()->getName() << "\n");
- llvm::errs() << "Running SLP Vectorizer pass\n";
+ // Collect memory operation groups
+ SmallVector<MemoryOpGroup> groups = collectMemoryOpGroups(*block);
+
+ LLVM_DEBUG(llvm::dbgs() << "Found " << groups.size()
+ << " memory operation groups in block\n");
+ });
}
std::unique_ptr<Pass> mlir::vector::createSLPVectorizerPass() {
>From a1a52f482d4ab546e84a44d7e81aaab0941e772a Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 00:10:20 +0200
Subject: [PATCH 04/52] contiguous groups
---
.../Vector/Transforms/SLPVectorizer.cpp | 106 +++++++++++++++++-
1 file changed, 104 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index bec5f9d90b21b..f46dc71537ef3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -51,6 +51,96 @@ struct MemoryOpGroup {
bool empty() const { return ops.empty(); }
};
+// Extract contiguous groups from a MemoryOpGroup
+SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
+ SmallVector<MemoryOpGroup> result;
+ if (group.ops.empty())
+ return result;
+
+ // Keep track of which operations we've processed
+ DenseSet<Operation *> processedOps;
+
+ // Process each operation
+ for (Operation *op : group.ops) {
+ // Skip if we've already processed this operation
+ if (processedOps.contains(op))
+ continue;
+
+ // Get base and index of current operation
+ Value base;
+ int64_t index = -1;
+ if (group.isLoadGroup()) {
+ auto loadOp = cast<memref::LoadOp>(op);
+ if (auto value = getConstantIntValue(loadOp.getIndices().front())) {
+ index = *value;
+ base = loadOp.getMemRef();
+ }
+ } else {
+ auto storeOp = cast<memref::StoreOp>(op);
+ if (auto value = getConstantIntValue(storeOp.getIndices().front())) {
+ index = *value;
+ base = storeOp.getMemRef();
+ }
+ }
+ if (index == -1)
+ continue;
+
+ // Start a new group with this operation
+ result.emplace_back(group.type);
+    MemoryOpGroup &currentGroup = result.back();
+ currentGroup.ops.push_back(op);
+ processedOps.insert(op);
+
+ LLVM_DEBUG(llvm::dbgs() << "Starting new group at base " << base
+ << " index " << index << "\n");
+
+ // Try to find operations with adjacent indices
+ bool foundMore;
+ do {
+ foundMore = false;
+ // Look for operations with index+1
+ for (Operation *otherOp : group.ops) {
+ if (processedOps.contains(otherOp))
+ continue;
+
+ Value otherBase;
+ int64_t otherIndex = -1;
+ if (group.isLoadGroup()) {
+ auto loadOp = cast<memref::LoadOp>(otherOp);
+ if (auto value = getConstantIntValue(loadOp.getIndices().front())) {
+ otherIndex = *value;
+ otherBase = loadOp.getMemRef();
+ }
+ } else {
+ auto storeOp = cast<memref::StoreOp>(otherOp);
+ if (auto value = getConstantIntValue(storeOp.getIndices().front())) {
+ otherIndex = *value;
+ otherBase = storeOp.getMemRef();
+ }
+ }
+
+ // Check if this operation has the same base and adjacent index
+ if (otherIndex != -1 && otherBase == base &&
+ otherIndex == currentGroup.ops.size()) {
+ currentGroup.ops.push_back(otherOp);
+ processedOps.insert(otherOp);
+ foundMore = true;
+ LLVM_DEBUG(llvm::dbgs()
+ << "Added operation with index " << otherIndex << "\n");
+ break;
+ }
+ }
+ } while (foundMore);
+ }
+
+ // Remove empty groups
+ result.erase(std::remove_if(result.begin(), result.end(),
+ [](const MemoryOpGroup &g) { return g.empty(); }),
+ result.end());
+
+ return result;
+}
+
/// This pass implements the SLP vectorizer. It detects consecutive operations
/// that can be put together into vector operations. The pass works bottom-up,
/// across basic blocks, in search of scalars to combine.
@@ -112,8 +202,20 @@ void SLPVectorizerPass::runOnOperation() {
// Collect memory operation groups
SmallVector<MemoryOpGroup> groups = collectMemoryOpGroups(*block);
- LLVM_DEBUG(llvm::dbgs() << "Found " << groups.size()
- << " memory operation groups in block\n");
+ // Process each group to find contiguous sequences
+ for (const auto &group : groups) {
+ SmallVector<MemoryOpGroup> contiguousGroups =
+ extractContiguousGroups(group);
+ LLVM_DEBUG({
+ llvm::dbgs() << "Found " << contiguousGroups.size()
+ << " contiguous groups in "
+ << (group.isLoadGroup() ? "load" : "store") << " group\n";
+ for (const auto &contigGroup : contiguousGroups) {
+ llvm::dbgs() << " Contiguous group with " << contigGroup.size()
+ << " operations\n";
+ }
+ });
+ }
});
}
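extractContiguousGroups then splits every run by base memref and constant-index adjacency; ops whose index is not a constant are ignored. A sketch with hypothetical IR (%m, %n and the indices are placeholders):

  %a = memref.load %m[%c0] : memref<8xi32>   // seeds a contiguous group
  %b = memref.load %m[%c1] : memref<8xi32>   // same base, adjacent index: joins it
  %c = memref.load %n[%c2] : memref<8xi32>   // different base: seeds its own group
  %d = memref.load %m[%i]  : memref<8xi32>   // non-constant index: skipped entirely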
>From 8b266979e83431e3d06bb1f665f69b7f630c2bb9 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 00:14:53 +0200
Subject: [PATCH 05/52] refac
---
.../Vector/Transforms/SLPVectorizer.cpp | 55 ++++++++-----------
1 file changed, 22 insertions(+), 33 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index f46dc71537ef3..9a0ba5264bc40 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -51,6 +51,18 @@ struct MemoryOpGroup {
bool empty() const { return ops.empty(); }
};
+// Helper function to extract base and index from a memory operation
+std::optional<std::pair<Value, int64_t>> getBaseAndIndex(Operation *op) {
+ if (auto loadOp = dyn_cast<memref::LoadOp>(op)) {
+ if (auto value = getConstantIntValue(loadOp.getIndices().front()))
+ return std::make_pair(loadOp.getMemRef(), *value);
+ } else if (auto storeOp = dyn_cast<memref::StoreOp>(op)) {
+ if (auto value = getConstantIntValue(storeOp.getIndices().front()))
+ return std::make_pair(storeOp.getMemRef(), *value);
+ }
+ return std::nullopt;
+}
+
// Extract contiguous groups from a MemoryOpGroup
SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
SmallVector<MemoryOpGroup> result;
@@ -67,24 +79,12 @@ SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
continue;
// Get base and index of current operation
- Value base;
- int64_t index = -1;
- if (group.isLoadGroup()) {
- auto loadOp = cast<memref::LoadOp>(op);
- if (auto value = getConstantIntValue(loadOp.getIndices().front())) {
- index = *value;
- base = loadOp.getMemRef();
- }
- } else {
- auto storeOp = cast<memref::StoreOp>(op);
- if (auto value = getConstantIntValue(storeOp.getIndices().front())) {
- index = *value;
- base = storeOp.getMemRef();
- }
- }
- if (index == -1)
+ auto baseAndIndex = getBaseAndIndex(op);
+ if (!baseAndIndex)
continue;
+ auto [base, index] = *baseAndIndex;
+
// Start a new group with this operation
result.emplace_back(group.type);
    MemoryOpGroup &currentGroup = result.back();
@@ -103,25 +103,14 @@ SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
if (processedOps.contains(otherOp))
continue;
- Value otherBase;
- int64_t otherIndex = -1;
- if (group.isLoadGroup()) {
- auto loadOp = cast<memref::LoadOp>(otherOp);
- if (auto value = getConstantIntValue(loadOp.getIndices().front())) {
- otherIndex = *value;
- otherBase = loadOp.getMemRef();
- }
- } else {
- auto storeOp = cast<memref::StoreOp>(otherOp);
- if (auto value = getConstantIntValue(storeOp.getIndices().front())) {
- otherIndex = *value;
- otherBase = storeOp.getMemRef();
- }
- }
+ auto otherBaseAndIndex = getBaseAndIndex(otherOp);
+ if (!otherBaseAndIndex)
+ continue;
+
+ auto [otherBase, otherIndex] = *otherBaseAndIndex;
// Check if this operation has the same base and adjacent index
- if (otherIndex != -1 && otherBase == base &&
- otherIndex == currentGroup.ops.size()) {
+ if (otherBase == base && otherIndex == currentGroup.ops.size()) {
currentGroup.ops.push_back(otherOp);
processedOps.insert(otherOp);
foundMore = true;
>From 903fee0da4587e36348f4d8359759d3de4cab652 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 00:35:42 +0200
Subject: [PATCH 06/52] SLPGraph
---
.../Vector/Transforms/SLPVectorizer.cpp | 162 ++++++++++++++++++
1 file changed, 162 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 9a0ba5264bc40..4355dc33648c0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -130,6 +130,160 @@ SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
return result;
}
+/// A node in the SLP graph representing a vectorizable operation
+struct SLPGraphNode {
+ Operation *op;
+ DenseSet<SLPGraphNode *> users;
+ DenseSet<SLPGraphNode *> operands;
+ bool isRoot = false;
+
+ SLPGraphNode(Operation *op) : op(op) {}
+};
+
+/// A graph of vectorizable operations
+class SLPGraph {
+public:
+ SLPGraph() = default;
+ ~SLPGraph() {
+ for (auto *node : nodes)
+ delete node;
+ }
+
+ /// Add a new node to the graph
+ SLPGraphNode *addNode(Operation *op) {
+ nodes.push_back(new SLPGraphNode(op));
+ return nodes.back();
+ }
+
+ /// Add a root node (memory operation)
+ SLPGraphNode *addRoot(Operation *op) {
+ auto *node = addNode(op);
+ node->isRoot = true;
+ return node;
+ }
+
+ /// Add a dependency edge between nodes
+ void addEdge(SLPGraphNode *from, SLPGraphNode *to) {
+ from->users.insert(to);
+ to->operands.insert(from);
+ }
+
+ /// Get all root nodes
+ SmallVector<SLPGraphNode *> getRoots() const {
+ SmallVector<SLPGraphNode *> roots;
+ for (auto *node : nodes)
+ if (node->isRoot)
+ roots.push_back(node);
+ return roots;
+ }
+
+ /// Print the graph structure
+ void print() const {
+ llvm::dbgs() << "SLP Graph Structure:\n";
+ llvm::dbgs() << "===================\n";
+
+ // First print all roots
+ llvm::dbgs() << "Roots:\n";
+ for (auto *node : nodes) {
+ if (!node->isRoot)
+ continue;
+ llvm::dbgs() << " " << *node->op << "\n";
+ llvm::dbgs() << " Users: ";
+ for (auto *user : node->users) {
+ llvm::dbgs() << "\n " << *user->op;
+ }
+ llvm::dbgs() << "\n";
+ }
+
+ // Then print all non-root nodes
+ llvm::dbgs() << "\nNon-root nodes:\n";
+ for (auto *node : nodes) {
+ if (node->isRoot)
+ continue;
+ llvm::dbgs() << " " << *node->op << "\n";
+ llvm::dbgs() << " Operands: ";
+ for (auto *operand : node->operands) {
+ llvm::dbgs() << "\n " << *operand->op;
+ }
+ llvm::dbgs() << "\n Users: ";
+ for (auto *user : node->users) {
+ llvm::dbgs() << "\n " << *user->op;
+ }
+ llvm::dbgs() << "\n";
+ }
+ llvm::dbgs() << "===================\n";
+ }
+
+private:
+ SmallVector<SLPGraphNode *> nodes;
+};
+
+/// Build the SLP graph starting from memory operation roots
+SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
+ SLPGraph graph;
+ DenseMap<Operation *, SLPGraphNode *> opToNode;
+
+ // First, add all memory operations as roots
+ for (const auto &group : rootGroups) {
+ for (Operation *op : group.ops) {
+ opToNode[op] = graph.addRoot(op);
+ }
+ }
+
+ // Process each root group to build the graph
+ for (const auto &group : rootGroups) {
+ for (Operation *rootOp : group.ops) {
+ // Get the value produced by this memory operation
+ Value rootValue = group.isLoadGroup()
+ ? cast<memref::LoadOp>(rootOp).getResult()
+ : cast<memref::StoreOp>(rootOp).getValue();
+
+ // Find all users of this value
+ for (Operation *user : rootValue.getUsers()) {
+ // Skip if we've already processed this operation
+ if (opToNode.contains(user))
+ continue;
+
+ // Check if this is a vectorizable operation
+ if (isa<arith::AddFOp, arith::AddIOp, arith::SubFOp, arith::SubIOp,
+ arith::MulFOp, arith::MulIOp>(user)) {
+ // Check if at least one other operand is already in the graph
+ bool hasGraphOperand = false;
+ for (Value operand : user->getOperands()) {
+ if (operand == rootValue)
+ continue;
+ if (auto *defOp = operand.getDefiningOp()) {
+ if (opToNode.contains(defOp)) {
+ hasGraphOperand = true;
+ break;
+ }
+ }
+ }
+
+ // Only add the operation if it has at least one other operand in the
+ // graph
+ if (hasGraphOperand) {
+ auto *node = graph.addNode(user);
+ opToNode[user] = node;
+ graph.addEdge(opToNode[rootOp], node);
+
+ // Add edges from other operands that are in the graph
+ for (Value operand : user->getOperands()) {
+ if (auto *defOp = operand.getDefiningOp()) {
+ if (opToNode.contains(defOp)) {
+ graph.addEdge(opToNode[defOp], node);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ return graph;
+}
+
/// This pass implements the SLP vectorizer. It detects consecutive operations
/// that can be put together into vector operations. The pass works bottom-up,
/// across basic blocks, in search of scalars to combine.
@@ -192,6 +346,7 @@ void SLPVectorizerPass::runOnOperation() {
SmallVector<MemoryOpGroup> groups = collectMemoryOpGroups(*block);
// Process each group to find contiguous sequences
+ SmallVector<MemoryOpGroup> rootGroups;
for (const auto &group : groups) {
SmallVector<MemoryOpGroup> contiguousGroups =
extractContiguousGroups(group);
@@ -204,7 +359,14 @@ void SLPVectorizerPass::runOnOperation() {
<< " operations\n";
}
});
+ rootGroups.append(contiguousGroups.begin(), contiguousGroups.end());
}
+
+ // Build the SLP graph from root groups
+ SLPGraph graph = buildSLPGraph(rootGroups);
+
+ // Print the graph structure
+ LLVM_DEBUG(graph.print());
});
}
>From 4f92ea6fffa572640cc70d82f06f336b99484642 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 00:53:58 +0200
Subject: [PATCH 07/52] SLPGraph
---
.../Vector/Transforms/SLPVectorizer.cpp | 129 ++++++------------
1 file changed, 38 insertions(+), 91 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 4355dc33648c0..3c4fc3a377244 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -130,29 +130,28 @@ SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
return result;
}
-/// A node in the SLP graph representing a vectorizable operation
+/// A node in the SLP graph representing a group of vectorizable operations
struct SLPGraphNode {
- Operation *op;
+ SmallVector<Operation *> ops;
DenseSet<SLPGraphNode *> users;
DenseSet<SLPGraphNode *> operands;
bool isRoot = false;
- SLPGraphNode(Operation *op) : op(op) {}
+ SLPGraphNode() = default;
+ SLPGraphNode(Operation *op) { ops.push_back(op); }
+ void addOp(Operation *op) { ops.push_back(op); }
};
/// A graph of vectorizable operations
class SLPGraph {
public:
SLPGraph() = default;
- ~SLPGraph() {
- for (auto *node : nodes)
- delete node;
- }
+ ~SLPGraph() = default;
/// Add a new node to the graph
SLPGraphNode *addNode(Operation *op) {
- nodes.push_back(new SLPGraphNode(op));
- return nodes.back();
+ nodes.push_back(std::make_unique<SLPGraphNode>(op));
+ return nodes.back().get();
}
/// Add a root node (memory operation)
@@ -171,9 +170,9 @@ class SLPGraph {
/// Get all root nodes
SmallVector<SLPGraphNode *> getRoots() const {
SmallVector<SLPGraphNode *> roots;
- for (auto *node : nodes)
+ for (const auto &node : nodes)
if (node->isRoot)
- roots.push_back(node);
+ roots.push_back(node.get());
return roots;
}
@@ -184,30 +183,50 @@ class SLPGraph {
// First print all roots
llvm::dbgs() << "Roots:\n";
- for (auto *node : nodes) {
+ for (const auto &node : nodes) {
if (!node->isRoot)
continue;
- llvm::dbgs() << " " << *node->op << "\n";
+ llvm::dbgs() << " "
+ << (isa<memref::LoadOp>(node->ops[0]) ? "LOAD" : "STORE")
+ << " group with " << node->ops.size() << " operations:\n";
+ for (auto *op : node->ops) {
+ llvm::dbgs() << " " << *op << "\n";
+ }
llvm::dbgs() << " Users: ";
for (auto *user : node->users) {
- llvm::dbgs() << "\n " << *user->op;
+ llvm::dbgs() << "\n Group with " << user->ops.size()
+ << " operations:";
+ for (auto *op : user->ops) {
+ llvm::dbgs() << "\n " << *op;
+ }
}
llvm::dbgs() << "\n";
}
// Then print all non-root nodes
llvm::dbgs() << "\nNon-root nodes:\n";
- for (auto *node : nodes) {
+ for (const auto &node : nodes) {
if (node->isRoot)
continue;
- llvm::dbgs() << " " << *node->op << "\n";
+ llvm::dbgs() << " Group with " << node->ops.size() << " operations:\n";
+ for (auto *op : node->ops) {
+ llvm::dbgs() << " " << *op << "\n";
+ }
llvm::dbgs() << " Operands: ";
for (auto *operand : node->operands) {
- llvm::dbgs() << "\n " << *operand->op;
+ llvm::dbgs() << "\n Group with " << operand->ops.size()
+ << " operations:";
+ for (auto *op : operand->ops) {
+ llvm::dbgs() << "\n " << *op;
+ }
}
llvm::dbgs() << "\n Users: ";
for (auto *user : node->users) {
- llvm::dbgs() << "\n " << *user->op;
+ llvm::dbgs() << "\n Group with " << user->ops.size()
+ << " operations:";
+ for (auto *op : user->ops) {
+ llvm::dbgs() << "\n " << *op;
+ }
}
llvm::dbgs() << "\n";
}
@@ -215,75 +234,9 @@ class SLPGraph {
}
private:
- SmallVector<SLPGraphNode *> nodes;
+ SmallVector<std::unique_ptr<SLPGraphNode>> nodes;
};
-/// Build the SLP graph starting from memory operation roots
-SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
- SLPGraph graph;
- DenseMap<Operation *, SLPGraphNode *> opToNode;
-
- // First, add all memory operations as roots
- for (const auto &group : rootGroups) {
- for (Operation *op : group.ops) {
- opToNode[op] = graph.addRoot(op);
- }
- }
-
- // Process each root group to build the graph
- for (const auto &group : rootGroups) {
- for (Operation *rootOp : group.ops) {
- // Get the value produced by this memory operation
- Value rootValue = group.isLoadGroup()
- ? cast<memref::LoadOp>(rootOp).getResult()
- : cast<memref::StoreOp>(rootOp).getValue();
-
- // Find all users of this value
- for (Operation *user : rootValue.getUsers()) {
- // Skip if we've already processed this operation
- if (opToNode.contains(user))
- continue;
-
- // Check if this is a vectorizable operation
- if (isa<arith::AddFOp, arith::AddIOp, arith::SubFOp, arith::SubIOp,
- arith::MulFOp, arith::MulIOp>(user)) {
- // Check if at least one other operand is already in the graph
- bool hasGraphOperand = false;
- for (Value operand : user->getOperands()) {
- if (operand == rootValue)
- continue;
- if (auto *defOp = operand.getDefiningOp()) {
- if (opToNode.contains(defOp)) {
- hasGraphOperand = true;
- break;
- }
- }
- }
-
- // Only add the operation if it has at least one other operand in the
- // graph
- if (hasGraphOperand) {
- auto *node = graph.addNode(user);
- opToNode[user] = node;
- graph.addEdge(opToNode[rootOp], node);
-
- // Add edges from other operands that are in the graph
- for (Value operand : user->getOperands()) {
- if (auto *defOp = operand.getDefiningOp()) {
- if (opToNode.contains(defOp)) {
- graph.addEdge(opToNode[defOp], node);
- }
- }
- }
- }
- }
- }
- }
- }
-
- return graph;
-}
-
/// This pass implements the SLP vectorizer. It detects consecutive operations
/// that can be put together into vector operations. The pass works bottom-up,
/// across basic blocks, in search of scalars to combine.
@@ -361,12 +314,6 @@ void SLPVectorizerPass::runOnOperation() {
});
rootGroups.append(contiguousGroups.begin(), contiguousGroups.end());
}
-
- // Build the SLP graph from root groups
- SLPGraph graph = buildSLPGraph(rootGroups);
-
- // Print the graph structure
- LLVM_DEBUG(graph.print());
});
}
>From 8dcc9bc4cf77beba2f1142cad92c011e52ae0271 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 01:04:27 +0200
Subject: [PATCH 08/52] work
---
.../Vector/Transforms/SLPVectorizer.cpp | 48 ++++++++++++++++---
1 file changed, 41 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 3c4fc3a377244..8e49b622ac39b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -138,8 +138,8 @@ struct SLPGraphNode {
bool isRoot = false;
SLPGraphNode() = default;
- SLPGraphNode(Operation *op) { ops.push_back(op); }
- void addOp(Operation *op) { ops.push_back(op); }
+ SLPGraphNode(ArrayRef<Operation *> operations)
+ : ops(operations.begin(), operations.end()) {}
};
/// A graph of vectorizable operations
@@ -148,15 +148,23 @@ class SLPGraph {
SLPGraph() = default;
~SLPGraph() = default;
+ // Delete copy constructor and assignment operator
+ SLPGraph(const SLPGraph &) = delete;
+ SLPGraph &operator=(const SLPGraph &) = delete;
+
+ // Allow move operations
+ SLPGraph(SLPGraph &&) = default;
+ SLPGraph &operator=(SLPGraph &&) = default;
+
/// Add a new node to the graph
- SLPGraphNode *addNode(Operation *op) {
- nodes.push_back(std::make_unique<SLPGraphNode>(op));
+ SLPGraphNode *addNode(ArrayRef<Operation *> operations) {
+ nodes.push_back(std::make_unique<SLPGraphNode>(operations));
return nodes.back().get();
}
/// Add a root node (memory operation)
- SLPGraphNode *addRoot(Operation *op) {
- auto *node = addNode(op);
+ SLPGraphNode *addRoot(ArrayRef<Operation *> operations) {
+ auto *node = addNode(operations);
node->isRoot = true;
return node;
}
@@ -251,7 +259,25 @@ struct SLPVectorizerPass
SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block);
};
-} // namespace
+/// Build the SLP graph starting from memory operation groups
+SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
+ SLPGraph graph;
+
+ // First, create nodes for each contiguous memory operation group
+ for (const auto &group : rootGroups) {
+ // Create a new node for this group
+ auto *node = graph.addRoot(group.ops);
+ node->isRoot = true;
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "Created " << (group.isLoadGroup() ? "LOAD" : "STORE")
+ << " group node with " << node->ops.size()
+ << " operations\n";
+ });
+ }
+
+ return graph;
+}
SmallVector<MemoryOpGroup>
SLPVectorizerPass::collectMemoryOpGroups(Block &block) {
@@ -314,9 +340,17 @@ void SLPVectorizerPass::runOnOperation() {
});
rootGroups.append(contiguousGroups.begin(), contiguousGroups.end());
}
+
+ // Build the SLP graph from root groups
+ SLPGraph graph = buildSLPGraph(rootGroups);
+
+ // Print the graph structure
+ LLVM_DEBUG(graph.print());
});
}
+} // namespace
+
std::unique_ptr<Pass> mlir::vector::createSLPVectorizerPass() {
return std::make_unique<SLPVectorizerPass>();
}
>From cee30a40bab9d2d474970880cc6b8c859a23ba2d Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 12:55:28 +0200
Subject: [PATCH 09/52] fingerprinting
---
.../Vector/Transforms/SLPVectorizer.cpp | 170 ++++++++++++++++--
1 file changed, 158 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 8e49b622ac39b..3e6a4ca05f87d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -21,6 +21,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/SHA1.h"
#define DEBUG_TYPE "slp-vectorizer"
@@ -64,7 +65,8 @@ std::optional<std::pair<Value, int64_t>> getBaseAndIndex(Operation *op) {
}
// Extract contiguous groups from a MemoryOpGroup
-SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
+static SmallVector<MemoryOpGroup>
+extractContiguousGroups(const MemoryOpGroup &group) {
SmallVector<MemoryOpGroup> result;
if (group.ops.empty())
return result;
@@ -133,8 +135,8 @@ SmallVector<MemoryOpGroup> extractContiguousGroups(const MemoryOpGroup &group) {
/// A node in the SLP graph representing a group of vectorizable operations
struct SLPGraphNode {
SmallVector<Operation *> ops;
- DenseSet<SLPGraphNode *> users;
- DenseSet<SLPGraphNode *> operands;
+ llvm::SmallDenseSet<SLPGraphNode *> users;
+ llvm::SmallDenseSet<SLPGraphNode *> operands;
bool isRoot = false;
SLPGraphNode() = default;
@@ -148,11 +150,9 @@ class SLPGraph {
SLPGraph() = default;
~SLPGraph() = default;
- // Delete copy constructor and assignment operator
SLPGraph(const SLPGraph &) = delete;
SLPGraph &operator=(const SLPGraph &) = delete;
- // Allow move operations
SLPGraph(SLPGraph &&) = default;
SLPGraph &operator=(SLPGraph &&) = default;
@@ -259,21 +259,168 @@ struct SLPVectorizerPass
SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block);
};
+static bool isVectorizable(Operation *op) {
+ return OpTrait::hasElementwiseMappableTraits(op);
+}
+
+using Fingerprint = std::array<uint8_t, 20>;
+
+template <typename T>
+static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
+ hasher.update(
+ ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
+}
+
+struct OperationsFingerprint {
+ OperationsFingerprint(
+ const llvm::SmallDenseMap<Operation *, SLPGraphNode *> &opToNode)
+ : opToNode(opToNode) {}
+
+ Fingerprint getFingerprint(Operation *op) {
+ auto it = fingerprints.find(op);
+ if (it != fingerprints.end())
+ return it->second;
+
+ SmallVector<Operation *> worklist;
+ SmallVector<Operation *> toposortedOps;
+ worklist.emplace_back(op);
+ while (!worklist.empty()) {
+ Operation *op = worklist.pop_back_val();
+ toposortedOps.emplace_back(op);
+ if (opToNode.contains(op))
+ continue;
+
+ for (Value operand : op->getOperands()) {
+ auto *defOp = operand.getDefiningOp();
+ if (!defOp || !isVectorizable(defOp))
+ continue;
+
+ toposortedOps.emplace_back(defOp);
+ worklist.emplace_back(defOp);
+ }
+ }
+
+ for (Operation *op : llvm::reverse(toposortedOps)) {
+ llvm::SHA1 hasher;
+ addDataToHash(hasher, op->getName().getTypeID());
+ addDataToHash(hasher, op->getRawDictionaryAttrs());
+ addDataToHash(hasher, op->hashProperties());
+ for (Value operand : op->getOperands()) {
+ auto *defOp = operand.getDefiningOp();
+ if (!defOp)
+ continue;
+
+ auto it1 = opToNode.find(defOp);
+ if (it1 != opToNode.end()) {
+ addDataToHash(hasher, it1->second);
+ continue;
+ }
+
+ auto it2 = fingerprints.find(defOp);
+ if (it2 != fingerprints.end()) {
+ addDataToHash(hasher, it2->second);
+ continue;
+ }
+ }
+ fingerprints[op] = hasher.result();
+ }
+
+ return fingerprints[op];
+ }
+
+ void invalidate(Operation *op) {
+ if (fingerprints.contains(op))
+ fingerprints.clear();
+ }
+
+ const llvm::SmallDenseMap<Operation *, SLPGraphNode *> &opToNode;
+ DenseMap<Operation *, Fingerprint> fingerprints;
+};
+
+static bool isEquivalent(Operation *op1, Operation *op2) {
+ if (op1->getName() != op2->getName())
+ return false;
+
+ if (op1->getRawDictionaryAttrs() != op2->getRawDictionaryAttrs())
+ return false;
+
+ return true;
+}
+
/// Build the SLP graph starting from memory operation groups
-SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
+static SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
SLPGraph graph;
+ llvm::SmallDenseMap<Operation *, SLPGraphNode *> opToNode;
+
+ SmallVector<SLPGraphNode *> worklist;
// First, create nodes for each contiguous memory operation group
for (const auto &group : rootGroups) {
- // Create a new node for this group
auto *node = graph.addRoot(group.ops);
- node->isRoot = true;
+ for (Operation *op : group.ops)
+ opToNode[op] = node;
+
+ worklist.push_back(node);
LLVM_DEBUG({
- llvm::dbgs() << "Created " << (group.isLoadGroup() ? "LOAD" : "STORE")
- << " group node with " << node->ops.size()
- << " operations\n";
+ llvm::dbgs() << "Created root group node with " << node->ops.size()
+ << " operations of type "
+ << (group.type == MemoryOpGroup::Type::Load ? "Load"
+ : "Store")
+ << "\n";
});
+
+ OperationsFingerprint fingerprints(opToNode);
+
+ auto processUse = [&](SLPGraphNode *node, OpOperand &use) {
+ Operation *user = use.getOwner();
+ if (opToNode.contains(user) || !isVectorizable(user))
+ return;
+
+ Fingerprint expectedFingerprint = fingerprints.getFingerprint(user);
+
+ SmallVector<Operation *> currentOps;
+ currentOps.emplace_back(user);
+ for (Operation *op : ArrayRef(node->ops).drop_front()) {
+ Operation *found = nullptr;
+ for (OpOperand &opUse : op->getUses()) {
+ if (opUse.getOperandNumber() != use.getOperandNumber())
+ continue;
+
+ Operation *useOwner = opUse.getOwner();
+ if (!isEquivalent(useOwner, user) ||
+ fingerprints.getFingerprint(useOwner) != expectedFingerprint)
+ continue;
+
+ found = useOwner;
+ break;
+ }
+ if (!found)
+ break;
+
+ currentOps.push_back(found);
+ }
+
+ if (currentOps.size() == 1)
+ return;
+
+ auto *newNode = graph.addNode(currentOps);
+ graph.addEdge(node, newNode);
+ for (Operation *op : currentOps) {
+ opToNode[op] = newNode;
+ fingerprints.invalidate(op);
+ }
+
+ worklist.push_back(newNode);
+ };
+
+ while (!worklist.empty()) {
+ SLPGraphNode *node = worklist.pop_back_val();
+
+ Operation *op = node->ops.front();
+ for (OpOperand &use : op->getUses())
+ processUse(node, use);
+ }
}
return graph;
@@ -314,7 +461,6 @@ SLPVectorizerPass::collectMemoryOpGroups(Block &block) {
void SLPVectorizerPass::runOnOperation() {
Operation *op = getOperation();
- MLIRContext *context = &getContext();
// Walk all blocks recursively
op->walk([&](Block *block) {
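The fingerprint is a SHA-1 digest over an op's name, attribute dictionary, properties and, recursively, the producers of its operands (already-grouped producers are hashed by their node), so two users of a root group are merged into one graph node only when their whole def trees look alike. A sketch with hypothetical IR (%m and the indices are placeholders; assume %x and %y are block arguments): starting from the load root {%a0, %a1}, the two adds get equal fingerprints and become a single two-lane node, while the mul does not match and is left out:

  %a0 = memref.load %m[%c0] : memref<8xi32>
  %a1 = memref.load %m[%c1] : memref<8xi32>
  %r0 = arith.addi %a0, %x : i32   // lane 0 user, fingerprint F
  %r1 = arith.addi %a1, %y : i32   // lane 1 user, also F -> same node as %r0
  %r2 = arith.muli %a1, %y : i32   // different op name -> different fingerprint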
>From 0045a6f5abea75220af7f4777f658a2c83ae0415 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 13:58:50 +0200
Subject: [PATCH 10/52] graph
---
.../Vector/Transforms/SLPVectorizer.cpp | 106 ++++++++++--------
1 file changed, 62 insertions(+), 44 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 3e6a4ca05f87d..28c53efea7512 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -185,7 +185,7 @@ class SLPGraph {
}
/// Print the graph structure
- void print() const {
+ [[maybe_unused]] void print() const {
llvm::dbgs() << "SLP Graph Structure:\n";
llvm::dbgs() << "===================\n";
@@ -348,7 +348,12 @@ static bool isEquivalent(Operation *op1, Operation *op2) {
}
/// Build the SLP graph starting from memory operation groups
-static SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
+static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
+ if (rootGroups.empty())
+ return SLPGraph();
+
+ LLVM_DEBUG(llvm::dbgs() << "=== Building SLP graph from " << rootGroups.size()
+ << " root groups ===\n");
SLPGraph graph;
llvm::SmallDenseMap<Operation *, SLPGraphNode *> opToNode;
@@ -365,61 +370,74 @@ static SLPGraph buildSLPGraph(const SmallVector<MemoryOpGroup> &rootGroups) {
LLVM_DEBUG({
llvm::dbgs() << "Created root group node with " << node->ops.size()
<< " operations of type "
- << (group.type == MemoryOpGroup::Type::Load ? "Load"
- : "Store")
- << "\n";
+ << (group.isLoadGroup() ? "Load" : "Store") << "\n";
});
+ }
- OperationsFingerprint fingerprints(opToNode);
-
- auto processUse = [&](SLPGraphNode *node, OpOperand &use) {
- Operation *user = use.getOwner();
- if (opToNode.contains(user) || !isVectorizable(user))
- return;
+ OperationsFingerprint fingerprints(opToNode);
+
+ auto processUse = [&](SLPGraphNode *node, OpOperand &use) {
+ Operation *user = use.getOwner();
+ auto it = opToNode.find(user);
+ if (it != opToNode.end()) {
+ LLVM_DEBUG(llvm::dbgs()
+ << " Adding edge from " << node->ops.front()->getName()
+ << " to " << it->first->getName() << "\n");
+ graph.addEdge(node, it->second);
+ return;
+ }
- Fingerprint expectedFingerprint = fingerprints.getFingerprint(user);
+ if (!isVectorizable(user))
+ return;
- SmallVector<Operation *> currentOps;
- currentOps.emplace_back(user);
- for (Operation *op : ArrayRef(node->ops).drop_front()) {
- Operation *found = nullptr;
- for (OpOperand &opUse : op->getUses()) {
- if (opUse.getOperandNumber() != use.getOperandNumber())
- continue;
+ Fingerprint expectedFingerprint = fingerprints.getFingerprint(user);
- Operation *useOwner = opUse.getOwner();
- if (!isEquivalent(useOwner, user) ||
- fingerprints.getFingerprint(useOwner) != expectedFingerprint)
- continue;
+ SmallVector<Operation *> currentOps;
+ currentOps.emplace_back(user);
+ for (Operation *op : ArrayRef(node->ops).drop_front()) {
+ Operation *found = nullptr;
+ for (OpOperand &opUse : op->getUses()) {
+ if (opUse.getOperandNumber() != use.getOperandNumber())
+ continue;
- found = useOwner;
- break;
- }
- if (!found)
- break;
+ Operation *useOwner = opUse.getOwner();
+ if (!isEquivalent(useOwner, user) ||
+ fingerprints.getFingerprint(useOwner) != expectedFingerprint)
+ continue;
- currentOps.push_back(found);
+ found = useOwner;
+ break;
}
+ if (!found)
+ break;
- if (currentOps.size() == 1)
- return;
+ currentOps.push_back(found);
+ }
- auto *newNode = graph.addNode(currentOps);
- graph.addEdge(node, newNode);
- for (Operation *op : currentOps) {
- opToNode[op] = newNode;
- fingerprints.invalidate(op);
- }
+ if (currentOps.size() == 1)
+ return;
- worklist.push_back(newNode);
- };
+ auto *newNode = graph.addNode(currentOps);
+ graph.addEdge(node, newNode);
+ for (Operation *op : currentOps) {
+ opToNode[op] = newNode;
+ fingerprints.invalidate(op);
+ }
- while (!worklist.empty()) {
- SLPGraphNode *node = worklist.pop_back_val();
+ worklist.push_back(newNode);
+ };
+
+ while (!worklist.empty()) {
+ SLPGraphNode *node = worklist.pop_back_val();
+ LLVM_DEBUG(llvm::dbgs() << "Processing node with " << node->ops.size()
+ << " operations, first op: "
+ << node->ops.front()->getName() << "\n");
- Operation *op = node->ops.front();
- for (OpOperand &use : op->getUses())
- processUse(node, use);
+ Operation *op = node->ops.front();
+ for (OpOperand &use : op->getUses()) {
+ processUse(node, use);
+ LLVM_DEBUG(llvm::dbgs() << " Processing use in operation: "
+ << use.getOwner()->getName() << "\n");
}
}
>From 8c29c3a291506d81565f605df314cf30b82f214b Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 14:05:07 +0200
Subject: [PATCH 11/52] refac
---
.../Vector/Transforms/SLPVectorizer.cpp | 41 ++++++++++---------
1 file changed, 22 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 28c53efea7512..e3b39ba10373c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -159,7 +159,10 @@ class SLPGraph {
/// Add a new node to the graph
SLPGraphNode *addNode(ArrayRef<Operation *> operations) {
nodes.push_back(std::make_unique<SLPGraphNode>(operations));
- return nodes.back().get();
+ auto *node = nodes.back().get();
+ for (Operation *op : operations)
+ opToNode[op] = node;
+ return node;
}
/// Add a root node (memory operation)
@@ -184,6 +187,12 @@ class SLPGraph {
return roots;
}
+ /// Get the node associated with an operation
+ SLPGraphNode *getNodeForOp(Operation *op) const {
+ auto it = opToNode.find(op);
+ return it != opToNode.end() ? it->second : nullptr;
+ }
+
/// Print the graph structure
[[maybe_unused]] void print() const {
llvm::dbgs() << "SLP Graph Structure:\n";
@@ -243,6 +252,7 @@ class SLPGraph {
private:
SmallVector<std::unique_ptr<SLPGraphNode>> nodes;
+ llvm::SmallDenseMap<Operation *, SLPGraphNode *> opToNode;
};
/// This pass implements the SLP vectorizer. It detects consecutive operations
@@ -272,9 +282,7 @@ static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
}
struct OperationsFingerprint {
- OperationsFingerprint(
- const llvm::SmallDenseMap<Operation *, SLPGraphNode *> &opToNode)
- : opToNode(opToNode) {}
+ OperationsFingerprint(const SLPGraph &graph) : graph(graph) {}
Fingerprint getFingerprint(Operation *op) {
auto it = fingerprints.find(op);
@@ -287,7 +295,7 @@ struct OperationsFingerprint {
while (!worklist.empty()) {
Operation *op = worklist.pop_back_val();
toposortedOps.emplace_back(op);
- if (opToNode.contains(op))
+ if (graph.getNodeForOp(op))
continue;
for (Value operand : op->getOperands()) {
@@ -310,9 +318,9 @@ struct OperationsFingerprint {
if (!defOp)
continue;
- auto it1 = opToNode.find(defOp);
- if (it1 != opToNode.end()) {
- addDataToHash(hasher, it1->second);
+ auto *node = graph.getNodeForOp(defOp);
+ if (node) {
+ addDataToHash(hasher, node);
continue;
}
@@ -333,7 +341,7 @@ struct OperationsFingerprint {
fingerprints.clear();
}
- const llvm::SmallDenseMap<Operation *, SLPGraphNode *> &opToNode;
+ const SLPGraph &graph;
DenseMap<Operation *, Fingerprint> fingerprints;
};
@@ -355,16 +363,12 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
LLVM_DEBUG(llvm::dbgs() << "=== Building SLP graph from " << rootGroups.size()
<< " root groups ===\n");
SLPGraph graph;
- llvm::SmallDenseMap<Operation *, SLPGraphNode *> opToNode;
SmallVector<SLPGraphNode *> worklist;
// First, create nodes for each contiguous memory operation group
for (const auto &group : rootGroups) {
auto *node = graph.addRoot(group.ops);
- for (Operation *op : group.ops)
- opToNode[op] = node;
-
worklist.push_back(node);
LLVM_DEBUG({
@@ -374,16 +378,16 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
});
}
- OperationsFingerprint fingerprints(opToNode);
+ OperationsFingerprint fingerprints(graph);
auto processUse = [&](SLPGraphNode *node, OpOperand &use) {
Operation *user = use.getOwner();
- auto it = opToNode.find(user);
- if (it != opToNode.end()) {
+ auto *existingNode = graph.getNodeForOp(user);
+ if (existingNode) {
LLVM_DEBUG(llvm::dbgs()
<< " Adding edge from " << node->ops.front()->getName()
- << " to " << it->first->getName() << "\n");
- graph.addEdge(node, it->second);
+ << " to " << user->getName() << "\n");
+ graph.addEdge(node, existingNode);
return;
}
@@ -420,7 +424,6 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
auto *newNode = graph.addNode(currentOps);
graph.addEdge(node, newNode);
for (Operation *op : currentOps) {
- opToNode[op] = newNode;
fingerprints.invalidate(op);
}
>From bda45c9202c3b068fadc0026976bfdd3fce79daf Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 14:59:20 +0200
Subject: [PATCH 12/52] toposort
---
.../Vector/Transforms/SLPVectorizer.cpp | 89 ++++++++++++++++++-
1 file changed, 85 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index e3b39ba10373c..8f0137a12d07b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -135,8 +135,8 @@ extractContiguousGroups(const MemoryOpGroup &group) {
/// A node in the SLP graph representing a group of vectorizable operations
struct SLPGraphNode {
SmallVector<Operation *> ops;
- llvm::SmallDenseSet<SLPGraphNode *> users;
- llvm::SmallDenseSet<SLPGraphNode *> operands;
+ SmallVector<SLPGraphNode *> users;
+ SmallVector<SLPGraphNode *> operands;
bool isRoot = false;
SLPGraphNode() = default;
@@ -174,8 +174,8 @@ class SLPGraph {
/// Add a dependency edge between nodes
void addEdge(SLPGraphNode *from, SLPGraphNode *to) {
- from->users.insert(to);
- to->operands.insert(from);
+ from->users.push_back(to);
+ to->operands.push_back(from);
}
/// Get all root nodes
@@ -193,6 +193,80 @@ class SLPGraph {
return it != opToNode.end() ? it->second : nullptr;
}
+ /// Topologically sort the nodes in the graph
+ SmallVector<SLPGraphNode *> topologicalSort() const {
+ SmallVector<SLPGraphNode *> result;
+ llvm::SmallDenseSet<SLPGraphNode *> visited;
+
+ SmallVector<SLPGraphNode *> stack;
+
+ // Process each node
+ for (const auto &node : nodes) {
+ if (visited.contains(node.get()))
+ continue;
+
+ stack.emplace_back(node.get());
+ while (!stack.empty()) {
+ SLPGraphNode *node = stack.pop_back_val();
+ if (visited.contains(node))
+ continue;
+
+ stack.push_back(node);
+
+ bool pushed = false;
+ for (SLPGraphNode *operand : node->operands) {
+ if (visited.contains(operand))
+ continue;
+
+ stack.push_back(operand);
+ pushed = true;
+ }
+
+ if (!pushed) {
+ visited.insert(node);
+ result.push_back(node);
+ }
+ }
+ }
+
+ return result;
+ }
+
+ /// Vectorize the operations in the graph
+ LogicalResult vectorize(IRRewriter &rewriter) {
+ if (nodes.empty())
+ return success();
+
+ LLVM_DEBUG(llvm::dbgs()
+ << "Vectorizing SLP graph with " << nodes.size() << " nodes\n");
+
+ // Get topologically sorted nodes
+ SmallVector<SLPGraphNode *> sortedNodes = topologicalSort();
+ if (sortedNodes.empty()) {
+ LLVM_DEBUG(llvm::dbgs() << "Failed to topologically sort nodes\n");
+ return failure();
+ }
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "Topologically sorted nodes:\n";
+ for (auto *node : sortedNodes) {
+ llvm::dbgs() << " Node with " << node->ops.size()
+ << " operations: " << node->ops.front()->getName() << "\n";
+ }
+ });
+
+ // TODO: Implement vectorization logic:
+ // 1. Process nodes in topological order
+ // 2. For each node:
+ // a. Check if all operands are vectorized
+ // b. Create vector operation
+ // c. Replace scalar operations with vector operation
+ // 3. Handle memory operations (loads/stores) specially
+ // 4. Update use-def chains
+
+ return success();
+ }
+
/// Print the graph structure
[[maybe_unused]] void print() const {
llvm::dbgs() << "SLP Graph Structure:\n";
@@ -513,6 +587,13 @@ void SLPVectorizerPass::runOnOperation() {
// Print the graph structure
LLVM_DEBUG(graph.print());
+
+ // Vectorize the graph
+ IRRewriter rewriter(&getContext());
+ if (failed(graph.vectorize(rewriter))) {
+ LLVM_DEBUG(llvm::dbgs() << "Failed to vectorize graph\n");
+ return signalPassFailure();
+ }
});
}
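A note on the order this DFS produces: a node is only emitted after all of its operand nodes, so the vectorization step below can assume producers are processed first. A minimal sketch of the resulting order (illustrative only; the function and value names are not from the patch):

    func.func @toposort_example(%m: memref<8xi32>) {
      %c0 = arith.constant 0 : index
      %c1 = arith.constant 1 : index
      // Node A: adjacent loads (no operand nodes).
      %a0 = memref.load %m[%c0] : memref<8xi32>
      %a1 = memref.load %m[%c1] : memref<8xi32>
      // Node B: elementwise ops; operand node is A.
      %b0 = arith.addi %a0, %a0 : i32
      %b1 = arith.addi %a1, %a1 : i32
      // Node C: adjacent stores; operand node is B.
      memref.store %b0, %m[%c0] : memref<8xi32>
      memref.store %b1, %m[%c1] : memref<8xi32>
      return
    }
    // Topological order: A (loads), B (adds), C (stores).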
>From 18666e164ac6b5a512748d01739074ff38ee178c Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 15:30:01 +0200
Subject: [PATCH 13/52] codegen
---
.../Vector/Transforms/SLPVectorizer.cpp | 68 +++++++++++++++----
1 file changed, 56 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 8f0137a12d07b..095ad4f11a91a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -132,6 +132,10 @@ extractContiguousGroups(const MemoryOpGroup &group) {
return result;
}
+static bool isVectorizable(Operation *op) {
+ return OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1;
+}
+
/// A node in the SLP graph representing a group of vectorizable operations
struct SLPGraphNode {
SmallVector<Operation *> ops;
@@ -255,14 +259,58 @@ class SLPGraph {
}
});
- // TODO: Implement vectorization logic:
- // 1. Process nodes in topological order
- // 2. For each node:
- // a. Check if all operands are vectorized
- // b. Create vector operation
- // c. Replace scalar operations with vector operation
- // 3. Handle memory operations (loads/stores) specially
- // 4. Update use-def chains
+ IRMapping mapping;
+ for (auto *node : sortedNodes) {
+ if (node->users.empty() && node->operands.empty())
+ continue;
+
+ Operation *op = node->ops.front();
+ rewriter.setInsertionPoint(op);
+ Location loc = op->getLoc();
+ int64_t numElements = node->ops.size();
+ if (auto load = dyn_cast<memref::LoadOp>(op)) {
+ auto vecType =
+ VectorType::get(numElements, load.getMemRefType().getElementType());
+ Value result = rewriter.create<vector::LoadOp>(
+ loc, vecType, load.getMemRef(), load.getIndices());
+ mapping.map(load.getResult(), result);
+ } else if (auto store = dyn_cast<memref::StoreOp>(op)) {
+ Value val = mapping.lookupOrDefault(store.getValueToStore());
+ rewriter.create<vector::StoreOp>(loc, val, store.getMemRef(),
+ store.getIndices());
+ } else if (isVectorizable(op)) {
+ auto vecType =
+ VectorType::get(numElements, op->getResultTypes().front());
+ for (auto [i, operand] : llvm::enumerate(op->getOperands())) {
+ if (getNodeForOp(operand.getDefiningOp()))
+ continue;
+
+ SmallVector<Value> args;
+ for (Operation *defOp : node->ops)
+ args.push_back(defOp->getOperand(i));
+
+ Value vector =
+ rewriter.create<vector::FromElementsOp>(loc, vecType, args);
+ mapping.map(operand, vector);
+ }
+
+ Operation *newOp = rewriter.clone(*op, mapping);
+ auto resVectorType =
+ VectorType::get(numElements, op->getResultTypes().front());
+ newOp->getResult(0).setType(resVectorType);
+
+ mapping.map(op->getResults(), newOp->getResults());
+ } else {
+ op->emitError("unsupported operation");
+ return failure();
+ }
+ }
+
+ for (auto *node : llvm::reverse(sortedNodes)) {
+ for (Operation *op : node->ops) {
+ rewriter.eraseOp(op);
+ }
+ }
return success();
}
@@ -343,10 +391,6 @@ struct SLPVectorizerPass
SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block);
};
-static bool isVectorizable(Operation *op) {
- return OpTrait::hasElementwiseMappableTraits(op);
-}
-
using Fingerprint = std::array<uint8_t, 20>;
template <typename T>
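For reference, this codegen step rewrites a 4-wide load/add/store chain roughly as below (a sketch consistent with the FileCheck lines added in the next patch; SSA names are illustrative):

    %c0 = arith.constant 0 : index
    %a = vector.load %arg0[%c0] : memref<8xi32>, vector<4xi32>
    %b = vector.load %arg1[%c0] : memref<8xi32>, vector<4xi32>
    %sum = arith.addi %a, %b : vector<4xi32>
    vector.store %sum, %arg0[%c0] : memref<8xi32>, vector<4xi32>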
>From 7d2d82583bc95142b4dcc9970369fad70ec667ed Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 15:53:41 +0200
Subject: [PATCH 14/52] test
---
mlir/test/Dialect/Vector/slp-vectorize.mlir | 13 ++++++++-----
1 file changed, 8 insertions(+), 5 deletions(-)
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index a07dd05dd16aa..266008e53ea43 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -1,10 +1,13 @@
// RUN: mlir-opt %s --slp-vectorizer | FileCheck %s
-// CHECK-LABEL: func @basic_slp
-func.func @basic_slp(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
- // CHECK: vector.transfer_read
- // CHECK: arith.addi
- // CHECK: vector.transfer_write
+// CHECK-LABEL: func @read_read_add_write
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+ // CHECK: vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
>From 5274bd9fb3ac9872f8f88e7f9c7368bb4c13acba Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 15:57:06 +0200
Subject: [PATCH 15/52] test
---
mlir/test/Dialect/Vector/slp-vectorize.mlir | 25 +++++++++++++++++++++
1 file changed, 25 insertions(+)
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 266008e53ea43..28a255f90a869 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -1,5 +1,30 @@
// RUN: mlir-opt %s --slp-vectorizer | FileCheck %s
+// CHECK-LABEL: func @read_write
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %0 = memref.load %arg0[%c0] : memref<8xi32>
+ %1 = memref.load %arg0[%c1] : memref<8xi32>
+ %2 = memref.load %arg0[%c2] : memref<8xi32>
+ %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+ memref.store %0, %arg0[%c0] : memref<8xi32>
+ memref.store %1, %arg0[%c1] : memref<8xi32>
+ memref.store %2, %arg0[%c2] : memref<8xi32>
+ memref.store %3, %arg0[%c3] : memref<8xi32>
+
+ return
+}
+
+
// CHECK-LABEL: func @read_read_add_write
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
func.func @read_read_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
>From e0f0c295a39895b707f9ca9cf66c02945baeed1d Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 16:25:45 +0200
Subject: [PATCH 16/52] fixes
---
.../Vector/Transforms/SLPVectorizer.cpp | 49 +++++++++++++------
mlir/test/Dialect/Vector/slp-vectorize.mlir | 36 ++++++++++++++
2 files changed, 70 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 095ad4f11a91a..a40131a1b10ff 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -264,24 +264,13 @@ class SLPGraph {
if (node->users.empty() && node->operands.empty())
continue;
+ int64_t numElements = node->ops.size();
Operation *op = node->ops.front();
rewriter.setInsertionPoint(op);
Location loc = op->getLoc();
- int64_t numElements = node->ops.size();
- if (auto load = dyn_cast<memref::LoadOp>(op)) {
- auto vecType =
- VectorType::get(numElements, load.getMemRefType().getElementType());
- Value result = rewriter.create<vector::LoadOp>(
- loc, vecType, load.getMemRef(), load.getIndices());
- mapping.map(load.getResult(), result);
- } else if (auto store = dyn_cast<memref::StoreOp>(op)) {
- Value val = mapping.lookupOrDefault(store.getValueToStore());
- rewriter.create<vector::StoreOp>(loc, val, store.getMemRef(),
- store.getIndices());
- } else if (isVectorizable(op)) {
- auto vecType =
- VectorType::get(numElements, op->getResultTypes().front());
- for (auto [i, operand] : llvm::enumerate(op->getOperands())) {
+
+ auto handleNonVectorInputs = [&](ValueRange operands) {
+ for (auto [i, operand] : llvm::enumerate(operands)) {
if (getNodeForOp(operand.getDefiningOp()))
continue;
@@ -289,17 +278,47 @@ class SLPGraph {
for (Operation *defOp : node->ops)
args.push_back(defOp->getOperand(i));
+ auto vecType = VectorType::get(numElements, operand.getType());
Value vector =
rewriter.create<vector::FromElementsOp>(loc, vecType, args);
mapping.map(operand, vector);
}
+ };
+
+ auto handleNonVectorOutputs = [&](Value newResult) {
+ for (auto [i, result] : llvm::enumerate(node->ops)) {
+ for (OpOperand &use : result->getUses()) {
+ Operation *useOwner = use.getOwner();
+ if (getNodeForOp(useOwner))
+ continue;
+
+ Value elem = rewriter.create<vector::ExtractOp>(loc, newResult, i);
+ use.set(elem);
+ }
+ }
+ };
+ if (auto load = dyn_cast<memref::LoadOp>(op)) {
+ auto vecType =
+ VectorType::get(numElements, load.getMemRefType().getElementType());
+ Value result = rewriter.create<vector::LoadOp>(
+ loc, vecType, load.getMemRef(), load.getIndices());
+ mapping.map(load.getResult(), result);
+ handleNonVectorOutputs(result);
+ } else if (auto store = dyn_cast<memref::StoreOp>(op)) {
+ handleNonVectorInputs(store.getValueToStore());
+ Value val = mapping.lookupOrDefault(store.getValueToStore());
+ rewriter.create<vector::StoreOp>(loc, val, store.getMemRef(),
+ store.getIndices());
+ } else if (isVectorizable(op)) {
+ handleNonVectorInputs(op->getOperands());
Operation *newOp = rewriter.clone(*op, mapping);
auto resVectorType =
VectorType::get(numElements, op->getResultTypes().front());
newOp->getResult(0).setType(resVectorType);
mapping.map(op->getResults(), newOp->getResults());
+ handleNonVectorOutputs(newOp->getResult(0));
} else {
op->emitError("unsupported operation");
return failure();
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 28a255f90a869..2b2b91d667e00 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt %s --slp-vectorizer | FileCheck %s
+
// CHECK-LABEL: func @read_write
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
func.func @read_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
@@ -24,6 +25,41 @@ func.func @read_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
return
}
+// CHECK-LABEL: func @read_read_add
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i32, i32, i32){
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+ // CHECK: %[[R0:.*]] = vector.extract %[[RES]][0] : i32 from vector<4xi32>
+ // CHECK: %[[R1:.*]] = vector.extract %[[RES]][1] : i32 from vector<4xi32>
+ // CHECK: %[[R2:.*]] = vector.extract %[[RES]][2] : i32 from vector<4xi32>
+ // CHECK: %[[R3:.*]] = vector.extract %[[RES]][3] : i32 from vector<4xi32>
+ // CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]] : i32, i32, i32, i32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %0 = memref.load %arg0[%c0] : memref<8xi32>
+ %1 = memref.load %arg0[%c1] : memref<8xi32>
+ %2 = memref.load %arg0[%c2] : memref<8xi32>
+ %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+ %4 = memref.load %arg1[%c0] : memref<8xi32>
+ %5 = memref.load %arg1[%c1] : memref<8xi32>
+ %6 = memref.load %arg1[%c2] : memref<8xi32>
+ %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+ %8 = arith.addi %0, %4 : i32
+ %9 = arith.addi %1, %5 : i32
+ %10 = arith.addi %2, %6 : i32
+ %11 = arith.addi %3, %7 : i32
+
+ return %8, %9, %10, %11 : i32, i32, i32, i32
+}
+
// CHECK-LABEL: func @read_read_add_write
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
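In short: operands produced outside the graph are gathered with vector.from_elements, and results used outside the graph are scalarized again with vector.extract. A sketch of both paths, matching the @read_read_add and @add_write tests above (SSA names are illustrative):

    // Scalar inputs gathered into a vector operand.
    %lhs = vector.from_elements %s0, %s1, %s2, %s3 : vector<4xi32>
    %res = arith.addi %lhs, %rhs : vector<4xi32>
    // Scalar uses outside the graph served by extracts.
    %r0 = vector.extract %res[0] : i32 from vector<4xi32>
    %r1 = vector.extract %res[1] : i32 from vector<4xi32>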
>From 659825b990e77797fcd989c3f9681048c557e949 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 17:41:39 +0200
Subject: [PATCH 17/52] fixes
---
.../Vector/Transforms/SLPVectorizer.cpp | 56 +++++++++++++++++--
mlir/test/Dialect/Vector/slp-vectorize.mlir | 29 ++++++++++
2 files changed, 79 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index a40131a1b10ff..ab0b3f549192f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -259,9 +259,13 @@ class SLPGraph {
}
});
+ auto isGoodNode = [&](SLPGraphNode *node) {
+ return node->users.empty() && node->operands.empty();
+ };
+
IRMapping mapping;
for (auto *node : sortedNodes) {
- if (node->users.empty() && node->operands.empty())
+ if (isGoodNode(node))
continue;
int64_t numElements = node->ops.size();
@@ -326,6 +330,9 @@ class SLPGraph {
}
for (auto *node : llvm::reverse(sortedNodes)) {
+ if (isGoodNode(node))
+ continue;
+
for (Operation *op : node->ops) {
rewriter.eraseOp(op);
}
@@ -560,10 +567,47 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
auto *newNode = graph.addNode(currentOps);
graph.addEdge(node, newNode);
- for (Operation *op : currentOps) {
+ for (Operation *op : currentOps)
fingerprints.invalidate(op);
+
+ worklist.push_back(newNode);
+ };
+
+ auto processOperands = [&](SLPGraphNode *node, Value operand, int64_t index) {
+ Operation *srcOp = operand.getDefiningOp();
+ if (!srcOp)
+ return;
+
+ auto *existingNode = graph.getNodeForOp(srcOp);
+ if (existingNode) {
+ LLVM_DEBUG(llvm::dbgs()
+ << " Adding edge from " << srcOp->getName() << " to "
+ << node->ops.front()->getName() << "\n");
+ graph.addEdge(existingNode, node);
+ return;
+ }
+
+ if (!isVectorizable(srcOp))
+ return;
+
+ SmallVector<Operation *> currentOps;
+ currentOps.emplace_back(srcOp);
+ for (Operation *op : ArrayRef(node->ops).drop_front()) {
+ Operation *otherOp = op->getOperand(index).getDefiningOp();
+ if (!otherOp || !isEquivalent(otherOp, srcOp))
+ break;
+
+ currentOps.push_back(otherOp);
}
+ if (currentOps.size() == 1)
+ return;
+
+ auto *newNode = graph.addNode(currentOps);
+ graph.addEdge(newNode, node);
+ for (Operation *op : currentOps)
+ fingerprints.invalidate(op);
+
worklist.push_back(newNode);
};
@@ -574,11 +618,11 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
<< node->ops.front()->getName() << "\n");
Operation *op = node->ops.front();
- for (OpOperand &use : op->getUses()) {
+ for (OpOperand &use : op->getUses())
processUse(node, use);
- LLVM_DEBUG(llvm::dbgs() << " Processing use in operation: "
- << use.getOwner()->getName() << "\n");
- }
+
+ for (auto [i, operand] : llvm::enumerate(op->getOperands()))
+ processOperands(node, operand, i);
}
return graph;
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 2b2b91d667e00..036e1fcbed5d5 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -60,6 +60,35 @@ func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i3
return %8, %9, %10, %11 : i32, i32, i32, i32
}
+// CHECK-LABEL: func @add_write
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: i32, %[[ARG7:.*]]: i32, %[[ARG8:.*]]: memref<8xi32>)
+func.func @add_write(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32,
+ %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32,
+ %arg8: memref<8xi32>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[A:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]] : vector<4xi32>
+ // CHECK: %[[B:.*]] = vector.from_elements %[[ARG4]], %[[ARG5]], %[[ARG6]], %[[ARG7]] : vector<4xi32>
+ // CHECK: %[[RES:.*]] = arith.addi %0, %1 : vector<4xi32>
+ // CHECK: vector.store %[[RES]], %[[ARG8]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %8 = arith.addi %arg0, %arg4 : i32
+ %9 = arith.addi %arg1, %arg5 : i32
+ %10 = arith.addi %arg2, %arg6 : i32
+ %11 = arith.addi %arg3, %arg7 : i32
+
+ memref.store %8, %arg8[%c0] : memref<8xi32>
+ memref.store %9, %arg8[%c1] : memref<8xi32>
+ memref.store %10, %arg8[%c2] : memref<8xi32>
+ memref.store %11, %arg8[%c3] : memref<8xi32>
+
+ return
+}
+
+
// CHECK-LABEL: func @read_read_add_write
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
>From edd6cb83352f9716d630ff9be7fe614c4e693605 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 18:07:37 +0200
Subject: [PATCH 18/52] handle size mismatch
---
.../Vector/Transforms/SLPVectorizer.cpp | 21 +++++++
mlir/test/Dialect/Vector/slp-vectorize.mlir | 62 ++++++++++++++++++-
2 files changed, 82 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index ab0b3f549192f..f54a9aba0e6c0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -302,6 +302,16 @@ class SLPGraph {
}
};
+ auto handleVecSizeMismatch = [&](Value arg) -> Value {
+ auto srcType = cast<VectorType>(arg.getType());
+ assert(srcType.getRank() == 1);
+ if (srcType.getDimSize(0) == numElements)
+ return arg;
+
+ return rewriter.create<vector::ExtractStridedSliceOp>(loc, arg, 0,
+ numElements, 1);
+ };
+
if (auto load = dyn_cast<memref::LoadOp>(op)) {
auto vecType =
VectorType::get(numElements, load.getMemRefType().getElementType());
@@ -312,6 +322,7 @@ class SLPGraph {
} else if (auto store = dyn_cast<memref::StoreOp>(op)) {
handleNonVectorInputs(store.getValueToStore());
Value val = mapping.lookupOrDefault(store.getValueToStore());
+ val = handleVecSizeMismatch(val);
rewriter.create<vector::StoreOp>(loc, val, store.getMemRef(),
store.getIndices());
} else if (isVectorizable(op)) {
@@ -319,6 +330,15 @@ class SLPGraph {
Operation *newOp = rewriter.clone(*op, mapping);
auto resVectorType =
VectorType::get(numElements, op->getResultTypes().front());
+
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(newOp);
+ for (OpOperand &operand : newOp->getOpOperands()) {
+ Value newOperand = handleVecSizeMismatch(operand.get());
+ operand.set(newOperand);
+ }
+ }
newOp->getResult(0).setType(resVectorType);
mapping.map(op->getResults(), newOp->getResults());
@@ -701,6 +721,7 @@ void SLPVectorizerPass::runOnOperation() {
LLVM_DEBUG(llvm::dbgs() << "Failed to vectorize graph\n");
return signalPassFailure();
}
+ op->dump();
});
}
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 036e1fcbed5d5..76592833a78b4 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -25,6 +25,31 @@ func.func @read_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
return
}
+
+// CHECK-LABEL: func @read_write_size_mistamtch
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_write_size_mistamtch(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[RES1:.*]] = vector.extract_strided_slice %[[RES]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+ // CHECK: vector.store %[[RES1]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %0 = memref.load %arg0[%c0] : memref<8xi32>
+ %1 = memref.load %arg0[%c1] : memref<8xi32>
+ %2 = memref.load %arg0[%c2] : memref<8xi32>
+ %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+ memref.store %0, %arg0[%c0] : memref<8xi32>
+ memref.store %1, %arg0[%c1] : memref<8xi32>
+
+ return
+}
+
+
// CHECK-LABEL: func @read_read_add
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i32, i32, i32){
@@ -60,6 +85,7 @@ func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i3
return %8, %9, %10, %11 : i32, i32, i32, i32
}
+
// CHECK-LABEL: func @add_write
// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: i32, %[[ARG7:.*]]: i32, %[[ARG8:.*]]: memref<8xi32>)
func.func @add_write(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32,
@@ -89,7 +115,6 @@ func.func @add_write(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32,
}
-
// CHECK-LABEL: func @read_read_add_write
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
func.func @read_read_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
@@ -125,3 +150,38 @@ func.func @read_read_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
return
}
+
+
+// CHECK-LABEL: func @read_read_add_write_size_mismatch
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_size_mismatch(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[A1:.*]] = vector.extract_strided_slice %[[A]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+ // CHECK: %[[B1:.*]] = vector.extract_strided_slice %[[B]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+ // CHECK: %[[RES:.*]] = arith.addi %[[A1]], %[[B1]] : vector<2xi32>
+ // CHECK: vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %0 = memref.load %arg0[%c0] : memref<8xi32>
+ %1 = memref.load %arg0[%c1] : memref<8xi32>
+ %2 = memref.load %arg0[%c2] : memref<8xi32>
+ %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+ %4 = memref.load %arg1[%c0] : memref<8xi32>
+ %5 = memref.load %arg1[%c1] : memref<8xi32>
+ %6 = memref.load %arg1[%c2] : memref<8xi32>
+ %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+ %8 = arith.addi %0, %4 : i32
+ %9 = arith.addi %1, %5 : i32
+
+ memref.store %8, %arg0[%c0] : memref<8xi32>
+ memref.store %9, %arg0[%c1] : memref<8xi32>
+
+ return
+}
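The size-mismatch handling narrows a wider producer down to the consumer's width before use, e.g. a 4-wide load node feeding a 2-wide store node (sketch matching the @read_read_add_write_size_mismatch test above; SSA names are illustrative):

    %v = vector.load %arg0[%c0] : memref<8xi32>, vector<4xi32>
    %v2 = vector.extract_strided_slice %v
        {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
    vector.store %v2, %arg0[%c0] : memref<8xi32>, vector<2xi32>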
>From a0d251b3124357f716dea667d95ee37b83c30d0e Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 18:53:54 +0200
Subject: [PATCH 19/52] adjacent indices
---
.../Vector/Transforms/SLPVectorizer.cpp | 98 ++++++++++---------
mlir/test/Dialect/Vector/slp-vectorize.mlir | 25 +++++
2 files changed, 79 insertions(+), 44 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index f54a9aba0e6c0..cc252a0e32c06 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -52,16 +52,43 @@ struct MemoryOpGroup {
bool empty() const { return ops.empty(); }
};
-// Helper function to extract base and index from a memory operation
-std::optional<std::pair<Value, int64_t>> getBaseAndIndex(Operation *op) {
- if (auto loadOp = dyn_cast<memref::LoadOp>(op)) {
- if (auto value = getConstantIntValue(loadOp.getIndices().front()))
- return std::make_pair(loadOp.getMemRef(), *value);
- } else if (auto storeOp = dyn_cast<memref::StoreOp>(op)) {
- if (auto value = getConstantIntValue(storeOp.getIndices().front()))
- return std::make_pair(storeOp.getMemRef(), *value);
+static ValueRange getIndices(Operation *op) {
+ if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+ return loadOp.getIndices();
+ if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+ return storeOp.getIndices();
+ return {};
+}
+
+static Type getElementType(Operation *op) {
+ if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+ return loadOp.getResult().getType();
+ if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+ return storeOp.getValueToStore().getType();
+ return {};
+}
+
+static bool isAdjacentIndices(Value idx1, Value idx2) {
+ if (auto c1 = getConstantIntValue(idx1)) {
+ if (auto c2 = getConstantIntValue(idx2))
+ return *c1 + 1 == *c2;
}
- return std::nullopt;
+ return false;
+}
+
+static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2) {
+ if (idx1.empty() || idx1.size() != idx2.size())
+ return false;
+
+ if (idx1.drop_back() != idx2.drop_back())
+ return false;
+
+ return isAdjacentIndices(idx1.back(), idx2.back());
+}
+
+static bool isAdjacentIndices(Operation *op1, Operation *op2) {
+ return getElementType(op1) == getElementType(op2) &&
+ isAdjacentIndices(getIndices(op1), getIndices(op2));
}
// Extract contiguous groups from a MemoryOpGroup
@@ -71,64 +98,48 @@ extractContiguousGroups(const MemoryOpGroup &group) {
if (group.ops.empty())
return result;
- // Keep track of which operations we've processed
- DenseSet<Operation *> processedOps;
+ llvm::SmallDenseSet<Operation *> processedOps;
- // Process each operation
for (Operation *op : group.ops) {
- // Skip if we've already processed this operation
if (processedOps.contains(op))
continue;
- // Get base and index of current operation
- auto baseAndIndex = getBaseAndIndex(op);
- if (!baseAndIndex)
- continue;
-
- auto [base, index] = *baseAndIndex;
-
// Start a new group with this operation
result.emplace_back(group.type);
MemoryOpGroup ¤tGroup = result.back();
- currentGroup.ops.push_back(op);
+ auto ¤tOps = currentGroup.ops;
+ currentOps.push_back(op);
processedOps.insert(op);
- LLVM_DEBUG(llvm::dbgs() << "Starting new group at base " << base
- << " index " << index << "\n");
-
- // Try to find operations with adjacent indices
bool foundMore;
do {
foundMore = false;
- // Look for operations with index+1
for (Operation *otherOp : group.ops) {
if (processedOps.contains(otherOp))
continue;
- auto otherBaseAndIndex = getBaseAndIndex(otherOp);
- if (!otherBaseAndIndex)
- continue;
-
- auto [otherBase, otherIndex] = *otherBaseAndIndex;
-
- // Check if this operation has the same base and adjacent index
- if (otherBase == base && otherIndex == currentGroup.ops.size()) {
- currentGroup.ops.push_back(otherOp);
+ Operation *firstOp = currentOps.front();
+ Operation *lastOp = currentOps.back();
+ if (isAdjacentIndices(otherOp, firstOp)) {
+ currentOps.insert(currentOps.begin(), otherOp);
+ processedOps.insert(otherOp);
+ foundMore = true;
+ } else if (isAdjacentIndices(lastOp, otherOp)) {
+ currentOps.push_back(otherOp);
processedOps.insert(otherOp);
foundMore = true;
- LLVM_DEBUG(llvm::dbgs()
- << "Added operation with index " << otherIndex << "\n");
- break;
}
}
} while (foundMore);
- }
- // Remove empty groups
- result.erase(std::remove_if(result.begin(), result.end(),
- [](const MemoryOpGroup &g) { return g.empty(); }),
- result.end());
+ if (currentOps.size() <= 1) {
+ result.pop_back();
+ continue;
+ }
+ LLVM_DEBUG(llvm::dbgs() << "Extracted contiguous group with "
+ << currentGroup.ops.size() << " operations\n");
+ }
return result;
}
@@ -721,7 +732,6 @@ void SLPVectorizerPass::runOnOperation() {
LLVM_DEBUG(llvm::dbgs() << "Failed to vectorize graph\n");
return signalPassFailure();
}
- op->dump();
});
}
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 76592833a78b4..6be405ad078b9 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -50,6 +50,31 @@ func.func @read_write_size_mistamtch(%arg0: memref<8xi32>, %arg1: memref<8xi32>)
}
+// CHECK-LABEL: func @read_write_interleaved
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_write_interleaved(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %3 = memref.load %arg0[%c3] : memref<8xi32>
+ %0 = memref.load %arg0[%c0] : memref<8xi32>
+ %2 = memref.load %arg0[%c2] : memref<8xi32>
+ %1 = memref.load %arg0[%c1] : memref<8xi32>
+
+ memref.store %1, %arg0[%c1] : memref<8xi32>
+ memref.store %0, %arg0[%c0] : memref<8xi32>
+ memref.store %3, %arg0[%c3] : memref<8xi32>
+ memref.store %2, %arg0[%c2] : memref<8xi32>
+
+ return
+}
+
+
// CHECK-LABEL: func @read_read_add
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i32, i32, i32){
>From 785f7568349377d9bfbefda73c5fc92c9122184d Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 18:57:52 +0200
Subject: [PATCH 20/52] test
---
mlir/test/Dialect/Vector/slp-vectorize.mlir | 38 +++++++++++++++++++++
1 file changed, 38 insertions(+)
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 6be405ad078b9..9c5005f807c71 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -210,3 +210,41 @@ func.func @read_read_add_write_size_mismatch(%arg0: memref<8xi32>, %arg1: memref
return
}
+
+
+// CHECK-LABEL: func @read_read_add_write_interleaved
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_interleaved(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+ // CHECK: vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %3 = memref.load %arg0[%c3] : memref<8xi32>
+ %7 = memref.load %arg1[%c3] : memref<8xi32>
+ %11 = arith.addi %3, %7 : i32
+
+ %0 = memref.load %arg0[%c0] : memref<8xi32>
+ %4 = memref.load %arg1[%c0] : memref<8xi32>
+ %8 = arith.addi %0, %4 : i32
+
+ %2 = memref.load %arg0[%c2] : memref<8xi32>
+ %6 = memref.load %arg1[%c2] : memref<8xi32>
+ %10 = arith.addi %2, %6 : i32
+
+ %1 = memref.load %arg0[%c1] : memref<8xi32>
+ %5 = memref.load %arg1[%c1] : memref<8xi32>
+ %9 = arith.addi %1, %5 : i32
+
+ memref.store %8, %arg0[%c0] : memref<8xi32>
+ memref.store %11, %arg0[%c3] : memref<8xi32>
+ memref.store %10, %arg0[%c2] : memref<8xi32>
+ memref.store %9, %arg0[%c1] : memref<8xi32>
+
+ return
+}
>From 6c18d433545d8854688caa140afe7fc18c961add Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 19:07:26 +0200
Subject: [PATCH 21/52] fixes and test
---
.../Vector/Transforms/SLPVectorizer.cpp | 11 +++-
mlir/test/Dialect/Vector/slp-vectorize.mlir | 54 +++++++++++++++++++
2 files changed, 64 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index cc252a0e32c06..3ff46093d9fbe 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -52,6 +52,14 @@ struct MemoryOpGroup {
bool empty() const { return ops.empty(); }
};
+static Value getBase(Operation *op) {
+ if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+ return loadOp.getMemRef();
+ if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+ return storeOp.getMemRef();
+ return {};
+}
+
static ValueRange getIndices(Operation *op) {
if (auto loadOp = dyn_cast<memref::LoadOp>(op))
return loadOp.getIndices();
@@ -87,7 +95,8 @@ static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2) {
}
static bool isAdjacentIndices(Operation *op1, Operation *op2) {
- return getElementType(op1) == getElementType(op2) &&
+ return getBase(op1) == getBase(op2) &&
+ getElementType(op1) == getElementType(op2) &&
isAdjacentIndices(getIndices(op1), getIndices(op2));
}
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 9c5005f807c71..820fbf2d260cd 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -248,3 +248,57 @@ func.func @read_read_add_write_interleaved(%arg0: memref<8xi32>, %arg1: memref<8
return
}
+
+
+// CHECK-LABEL: func @read_read_add_add_write
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>
+// CHECK-SAME: , %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
+func.func @read_read_add_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>,
+ %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[ADD1:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+ // CHECK: %[[C:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]] : vector<4xi32>
+ // CHECK: %[[ADD2:.*]] = arith.addi %[[A]], %[[C]] : vector<4xi32>
+ // CHECK: vector.store %[[ADD1]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: vector.store %[[ADD2]], %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %0 = memref.load %arg0[%c0] : memref<8xi32>
+ %1 = memref.load %arg0[%c1] : memref<8xi32>
+ %2 = memref.load %arg0[%c2] : memref<8xi32>
+ %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+ %4 = memref.load %arg1[%c0] : memref<8xi32>
+ %5 = memref.load %arg1[%c1] : memref<8xi32>
+ %6 = memref.load %arg1[%c2] : memref<8xi32>
+ %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+ %8 = arith.addi %0, %4 : i32
+ %12 = arith.addi %0, %arg2 : i32
+
+ %13 = arith.addi %1, %arg3 : i32
+ %9 = arith.addi %1, %5 : i32
+
+ %10 = arith.addi %2, %6 : i32
+ %14 = arith.addi %2, %arg4 : i32
+
+ %15 = arith.addi %3, %arg5 : i32
+ %11 = arith.addi %3, %7 : i32
+
+ memref.store %8, %arg0[%c0] : memref<8xi32>
+ memref.store %9, %arg0[%c1] : memref<8xi32>
+ memref.store %10, %arg0[%c2] : memref<8xi32>
+ memref.store %11, %arg0[%c3] : memref<8xi32>
+
+ memref.store %12, %arg1[%c0] : memref<8xi32>
+ memref.store %13, %arg1[%c1] : memref<8xi32>
+ memref.store %14, %arg1[%c2] : memref<8xi32>
+ memref.store %15, %arg1[%c3] : memref<8xi32>
+
+ return
+}
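The added getBase check keeps accesses to different memrefs in separate groups even when their element types and constant indices would otherwise line up. A hypothetical pair of loads that is no longer treated as adjacent:

    // Different base memrefs, so not adjacent; each base forms its own group.
    %0 = memref.load %arg0[%c0] : memref<8xi32>
    %1 = memref.load %arg1[%c1] : memref<8xi32>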
>From 3e912a3d5ab5003819eafec5d17a957ad3bba9f5 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 19:22:03 +0200
Subject: [PATCH 22/52] better side effects handling
---
.../Vector/Transforms/SLPVectorizer.cpp | 94 +++++++++++--------
1 file changed, 55 insertions(+), 39 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 3ff46093d9fbe..6cb6faa486702 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -52,6 +52,61 @@ struct MemoryOpGroup {
bool empty() const { return ops.empty(); }
};
+static bool isReadOp(Operation *op) {
+ auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
+ if (!effectInterface)
+ return true;
+
+ return effectInterface.hasEffect<MemoryEffects::Read>();
+}
+
+static bool isWriteOp(Operation *op) {
+ auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
+ if (!effectInterface)
+ return true;
+
+ return effectInterface.hasEffect<MemoryEffects::Write>();
+}
+
+/// Collect all memory operations in the block into groups.
+/// Each group contains either all loads or all stores, uninterrupted by
+/// operations of the other type.
+static SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block) {
+ SmallVector<MemoryOpGroup> groups;
+ MemoryOpGroup *currentGroup = nullptr;
+
+ for (Operation &op : block) {
+ if (currentGroup) {
+ if (currentGroup->isLoadGroup() && isWriteOp(&op)) {
+ currentGroup = nullptr;
+ } else if (currentGroup->isStoreGroup() && isReadOp(&op)) {
+ currentGroup = nullptr;
+ }
+ }
+
+ if (!isa<memref::LoadOp, memref::StoreOp>(op))
+ continue;
+
+ bool isLoad = isReadOp(&op);
+ MemoryOpGroup::Type type =
+ isLoad ? MemoryOpGroup::Type::Load : MemoryOpGroup::Type::Store;
+
+ if (!currentGroup) {
+ groups.emplace_back(type);
+ currentGroup = &groups.back();
+ }
+
+ currentGroup->ops.push_back(&op);
+ }
+
+ // Remove empty groups
+ groups.erase(std::remove_if(groups.begin(), groups.end(),
+ [](const MemoryOpGroup &g) { return g.empty(); }),
+ groups.end());
+
+ return groups;
+}
+
static Value getBase(Operation *op) {
if (auto loadOp = dyn_cast<memref::LoadOp>(op))
return loadOp.getMemRef();
@@ -449,12 +504,6 @@ class SLPGraph {
struct SLPVectorizerPass
: public mlir::vector::impl::SLPVectorizerBase<SLPVectorizerPass> {
void runOnOperation() override;
-
-private:
- /// Collect all memory operations in the block into groups.
- /// Each group contains either all loads or all stores, uninterrupted by
- /// operations of the other type.
- SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block);
};
using Fingerprint = std::array<uint8_t, 20>;
@@ -668,39 +717,6 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
return graph;
}
-SmallVector<MemoryOpGroup>
-SLPVectorizerPass::collectMemoryOpGroups(Block &block) {
- SmallVector<MemoryOpGroup> groups;
- MemoryOpGroup *currentGroup = nullptr;
-
- for (Operation &op : block) {
- // Skip non-memory operations
- if (!isa<memref::LoadOp, memref::StoreOp>(op))
- continue;
-
- bool isLoad = isa<memref::LoadOp>(op);
- MemoryOpGroup::Type type =
- isLoad ? MemoryOpGroup::Type::Load : MemoryOpGroup::Type::Store;
-
- // Start a new group if:
- // 1. We don't have a current group, or
- // 2. The current operation is a different type than the current group
- if (!currentGroup || currentGroup->type != type) {
- groups.emplace_back(type);
- currentGroup = &groups.back();
- }
-
- currentGroup->ops.push_back(&op);
- }
-
- // Remove empty groups
- groups.erase(std::remove_if(groups.begin(), groups.end(),
- [](const MemoryOpGroup &g) { return g.empty(); }),
- groups.end());
-
- return groups;
-}
-
void SLPVectorizerPass::runOnOperation() {
Operation *op = getOperation();
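With the effect-aware collection, any write terminates the current load group and any read terminates the current store group, so loads separated by a store end up in different groups. A hypothetical illustration:

    %0 = memref.load %arg0[%c0] : memref<8xi32>   // load group #1
    %1 = memref.load %arg0[%c1] : memref<8xi32>   // load group #1
    memref.store %v, %arg0[%c4] : memref<8xi32>   // write: closes group #1
    %2 = memref.load %arg0[%c2] : memref<8xi32>   // load group #2
    %3 = memref.load %arg0[%c3] : memref<8xi32>   // load group #2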
>From c851b5da448d4e25071495d773899205d37d2614 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 19:44:47 +0200
Subject: [PATCH 23/52] cleanup
---
.../mlir/Dialect/Vector/Transforms/Passes.h | 3 --
.../mlir/Dialect/Vector/Transforms/Passes.td | 14 ++++--
.../Vector/Transforms/SLPVectorizer.cpp | 49 ++++++++++++-------
mlir/test/Dialect/Vector/slp-vectorize.mlir | 2 +-
4 files changed, 43 insertions(+), 25 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index 43112f084dc60..5667f4fa95ace 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -25,9 +25,6 @@ std::unique_ptr<Pass> createLowerVectorMultiReductionPass(
VectorMultiReductionLowering option =
VectorMultiReductionLowering::InnerParallel);
-/// Creates a pass that implements the SLP vectorizer.
-std::unique_ptr<Pass> createSLPVectorizerPass();
-
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 94ccd61cb5170..d5c31c9f78409 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -34,15 +34,21 @@ def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::Func
];
}
-def SLPVectorizer : Pass<"slp-vectorizer", "ModuleOp"> {
+def GreedySLPVectorizer : Pass<"greedy-slp-vectorizer"> {
let summary = "SLP Vectorizer Pass";
let description = [{
This pass implements the SLP (Superword Level Parallelism) vectorizer.
It detects consecutive operations that can be put together into vector
- operations. The pass works bottom-up, across basic blocks, in search of
- scalars to combine.
+ operations. The pass works bi-directionally, starting from reads or stores,
+ in search of scalars to combine.
+
+ This is a greedy vectorizer: it doesn't have a cost model (yet) and it tries
+ to create vector ops whenever there are at least 2 potential ops.
+
+ It doesn't check whether the target actually supports the resulting vectors;
+ the user will need a follow-up pass that splits large and/or unaligned
+ vectors into sizes actually supported by the target.
}];
- let constructor = "mlir::vector::createSLPVectorizerPass()";
let dependentDialects = ["mlir::vector::VectorDialect"];
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 6cb6faa486702..d7c2dc3845cac 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -27,7 +27,7 @@
namespace mlir {
namespace vector {
-#define GEN_PASS_DEF_SLPVECTORIZER
+#define GEN_PASS_DEF_GREEDYSLPVECTORIZER
#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
} // namespace vector
} // namespace mlir
@@ -115,6 +115,19 @@ static Value getBase(Operation *op) {
return {};
}
+static bool isContiguousLastDim(Value val) {
+ auto memrefType = dyn_cast<MemRefType>(val.getType());
+ if (!memrefType)
+ return false;
+
+ int64_t offset;
+ SmallVector<int64_t> strides;
+ if (failed(memrefType.getStridesAndOffset(strides, offset)))
+ return false;
+
+ return !strides.empty() && strides.back() == 1;
+}
+
static ValueRange getIndices(Operation *op) {
if (auto loadOp = dyn_cast<memref::LoadOp>(op))
return loadOp.getIndices();
@@ -150,8 +163,15 @@ static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2) {
}
static bool isAdjacentIndices(Operation *op1, Operation *op2) {
- return getBase(op1) == getBase(op2) &&
- getElementType(op1) == getElementType(op2) &&
+ Value base1 = getBase(op1);
+ Value base2 = getBase(op2);
+ if (base1 != base2)
+ return false;
+
+ if (!isContiguousLastDim(base1))
+ return false;
+
+ return getElementType(op1) == getElementType(op2) &&
isAdjacentIndices(getIndices(op1), getIndices(op2));
}
@@ -498,11 +518,9 @@ class SLPGraph {
llvm::SmallDenseMap<Operation *, SLPGraphNode *> opToNode;
};
-/// This pass implements the SLP vectorizer. It detects consecutive operations
-/// that can be put together into vector operations. The pass works bottom-up,
-/// across basic blocks, in search of scalars to combine.
-struct SLPVectorizerPass
- : public mlir::vector::impl::SLPVectorizerBase<SLPVectorizerPass> {
+struct GreedySLPVectorizerPass
+ : public mlir::vector::impl::GreedySLPVectorizerBase<
+ GreedySLPVectorizerPass> {
void runOnOperation() override;
};
@@ -717,11 +735,11 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
return graph;
}
-void SLPVectorizerPass::runOnOperation() {
+void GreedySLPVectorizerPass::runOnOperation() {
Operation *op = getOperation();
// Walk all blocks recursively
- op->walk([&](Block *block) {
+ op->walk([&](Block *block) -> WalkResult {
LLVM_DEBUG(llvm::dbgs() << "Processing block in operation: "
<< block->getParentOp()->getName() << "\n");
@@ -747,21 +765,18 @@ void SLPVectorizerPass::runOnOperation() {
// Build the SLP graph from root groups
SLPGraph graph = buildSLPGraph(rootGroups);
-
- // Print the graph structure
LLVM_DEBUG(graph.print());
// Vectorize the graph
IRRewriter rewriter(&getContext());
if (failed(graph.vectorize(rewriter))) {
LLVM_DEBUG(llvm::dbgs() << "Failed to vectorize graph\n");
- return signalPassFailure();
+ signalPassFailure();
+ return WalkResult::interrupt();
}
+
+ return WalkResult::advance();
});
}
} // namespace
-
-std::unique_ptr<Pass> mlir::vector::createSLPVectorizerPass() {
- return std::make_unique<SLPVectorizerPass>();
-}
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 820fbf2d260cd..2e9298d11ed05 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --slp-vectorizer | FileCheck %s
+// RUN: mlir-opt %s --greedy-slp-vectorizer | FileCheck %s
// CHECK-LABEL: func @read_write
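The new isContiguousLastDim check means only memrefs whose innermost stride is statically 1 can form contiguous groups. A hypothetical example of accesses that are skipped after this change:

    // Innermost stride is 2, so these loads stay scalar.
    %0 = memref.load %m[%c0] : memref<8xi32, strided<[2]>>
    %1 = memref.load %m[%c1] : memref<8xi32, strided<[2]>>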
>From bcabdf75314c15271652e9d34245e06a199dca5c Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 20:10:02 +0200
Subject: [PATCH 24/52] cleanup
---
.../Vector/Transforms/SLPVectorizer.cpp | 80 +++++++++++++------
1 file changed, 57 insertions(+), 23 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index d7c2dc3845cac..24059ec355b30 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -49,7 +49,6 @@ struct MemoryOpGroup {
bool isStoreGroup() const { return type == Type::Store; }
size_t size() const { return ops.size(); }
- bool empty() const { return ops.empty(); }
};
static bool isReadOp(Operation *op) {
@@ -99,11 +98,6 @@ static SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block) {
currentGroup->ops.push_back(&op);
}
- // Remove empty groups
- groups.erase(std::remove_if(groups.begin(), groups.end(),
- [](const MemoryOpGroup &g) { return g.empty(); }),
- groups.end());
-
return groups;
}
@@ -144,14 +138,19 @@ static Type getElementType(Operation *op) {
return {};
}
+/// Check if two indices are consecutive, i.e. the fastest index differs by 1.
static bool isAdjacentIndices(Value idx1, Value idx2) {
if (auto c1 = getConstantIntValue(idx1)) {
if (auto c2 = getConstantIntValue(idx2))
return *c1 + 1 == *c2;
}
+
+ // TODO: Check arith.add, affine.apply, etc
return false;
}
+/// Check if two ranges of indices are consecutive, i.e. the fastest index
+/// differs by 1 and all other indices are the same.
static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2) {
if (idx1.empty() || idx1.size() != idx2.size())
return false;
@@ -162,7 +161,10 @@ static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2) {
return isAdjacentIndices(idx1.back(), idx2.back());
}
-static bool isAdjacentIndices(Operation *op1, Operation *op2) {
+/// Check if two operations are adjacent and can be combined into a vector op.
+/// This is done by checking that the base memrefs are the same, the last
+/// dimension is contiguous, and the element types and indices are compatible.
+static bool isAdjacentOps(Operation *op1, Operation *op2) {
Value base1 = getBase(op1);
Value base2 = getBase(op2);
if (base1 != base2)
@@ -195,6 +197,8 @@ extractContiguousGroups(const MemoryOpGroup &group) {
currentOps.push_back(op);
processedOps.insert(op);
+ // Keep adding ops to the beginning or end of the current group until no
+ // more ops can be added.
bool foundMore;
do {
foundMore = false;
@@ -204,11 +208,11 @@ extractContiguousGroups(const MemoryOpGroup &group) {
Operation *firstOp = currentOps.front();
Operation *lastOp = currentOps.back();
- if (isAdjacentIndices(otherOp, firstOp)) {
+ if (isAdjacentOps(otherOp, firstOp)) {
currentOps.insert(currentOps.begin(), otherOp);
processedOps.insert(otherOp);
foundMore = true;
- } else if (isAdjacentIndices(lastOp, otherOp)) {
+ } else if (isAdjacentOps(lastOp, otherOp)) {
currentOps.push_back(otherOp);
processedOps.insert(otherOp);
foundMore = true;
@@ -222,7 +226,7 @@ extractContiguousGroups(const MemoryOpGroup &group) {
}
LLVM_DEBUG(llvm::dbgs() << "Extracted contiguous group with "
- << currentGroup.ops.size() << " operations\n");
+ << currentGroup.size() << " operations\n");
}
return result;
}
@@ -241,6 +245,8 @@ struct SLPGraphNode {
SLPGraphNode() = default;
SLPGraphNode(ArrayRef<Operation *> operations)
: ops(operations.begin(), operations.end()) {}
+
+ size_t size() const { return ops.size(); }
};
/// A graph of vectorizable operations
@@ -349,7 +355,7 @@ class SLPGraph {
LLVM_DEBUG({
llvm::dbgs() << "Topologically sorted nodes:\n";
for (auto *node : sortedNodes) {
- llvm::dbgs() << " Node with " << node->ops.size()
+ llvm::dbgs() << " Node with " << node->size()
<< " operations: " << node->ops.front()->getName() << "\n";
}
});
@@ -363,7 +369,7 @@ class SLPGraph {
if (isGoodNode(node))
continue;
- int64_t numElements = node->ops.size();
+ int64_t numElements = node->size();
Operation *op = node->ops.front();
rewriter.setInsertionPoint(op);
Location loc = op->getLoc();
@@ -467,15 +473,15 @@ class SLPGraph {
if (!node->isRoot)
continue;
llvm::dbgs() << " "
- << (isa<memref::LoadOp>(node->ops[0]) ? "LOAD" : "STORE")
- << " group with " << node->ops.size() << " operations:\n";
+ << (isa<memref::LoadOp>(node->ops.front()) ? "LOAD"
+ : "STORE")
+ << " group with " << node->size() << " operations:\n";
for (auto *op : node->ops) {
llvm::dbgs() << " " << *op << "\n";
}
llvm::dbgs() << " Users: ";
for (auto *user : node->users) {
- llvm::dbgs() << "\n Group with " << user->ops.size()
- << " operations:";
+ llvm::dbgs() << "\n Group with " << user->size() << " operations:";
for (auto *op : user->ops) {
llvm::dbgs() << "\n " << *op;
}
@@ -488,13 +494,13 @@ class SLPGraph {
for (const auto &node : nodes) {
if (node->isRoot)
continue;
- llvm::dbgs() << " Group with " << node->ops.size() << " operations:\n";
+ llvm::dbgs() << " Group with " << node->size() << " operations:\n";
for (auto *op : node->ops) {
llvm::dbgs() << " " << *op << "\n";
}
llvm::dbgs() << " Operands: ";
for (auto *operand : node->operands) {
- llvm::dbgs() << "\n Group with " << operand->ops.size()
+ llvm::dbgs() << "\n Group with " << operand->size()
<< " operations:";
for (auto *op : operand->ops) {
llvm::dbgs() << "\n " << *op;
@@ -502,8 +508,7 @@ class SLPGraph {
}
llvm::dbgs() << "\n Users: ";
for (auto *user : node->users) {
- llvm::dbgs() << "\n Group with " << user->ops.size()
- << " operations:";
+ llvm::dbgs() << "\n Group with " << user->size() << " operations:";
for (auto *op : user->ops) {
llvm::dbgs() << "\n " << *op;
}
@@ -518,6 +523,28 @@ class SLPGraph {
llvm::SmallDenseMap<Operation *, SLPGraphNode *> opToNode;
};
+/// This pass implements the greedy SLP vectorizer. It detects consecutive
+/// operations that can be put together into vector operations. The pass works
+/// bi-directionally, starting from reads or stores, in search of scalars to
+/// combine.
+///
+/// The pass is split into multiple steps:
+/// 1. Collect memory operation groups within the same block.
+/// A group is either multiple loads uninterrupted by stores or multiple
+/// stores uninterrupted by loads.
+///
+/// 2. Extract contiguous groups from memory operation groups, based on the
+/// ops base memrefs, load/store element types, and indices.
+///
+/// 3. Build SLP graph from contiguous groups. This is done by going both
+/// top-down and bottom-up through uses/operands respectively, starting from
+/// contiguous memory operation groups.
+///
+/// 4. Vectorize SLP graph. This is done by topological sort of the graph and
+/// vectorizing each node in the order of the sort.
+///
+/// Vectorization is done by cloning the operations and mapping the operands and
+/// results.
struct GreedySLPVectorizerPass
: public mlir::vector::impl::GreedySLPVectorizerBase<
GreedySLPVectorizerPass> {
@@ -532,6 +559,10 @@ static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
}
+/// The SLP vectorizer is bi-directional, so when we go top-down we can have
+/// multiple users with the same immediate op type; this class tries to compute
+/// a fingerprint for such ops based on the entire ops graph to maximize further
+/// scalar ops merging.
struct OperationsFingerprint {
OperationsFingerprint(const SLPGraph &graph) : graph(graph) {}
@@ -606,7 +637,8 @@ static bool isEquivalent(Operation *op1, Operation *op2) {
return true;
}
-/// Build the SLP graph starting from memory operation groups
+/// Build the SLP graph starting from memory operation groups and going both
+/// top-down and bottom-up through uses/operands respectively.
static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
if (rootGroups.empty())
return SLPGraph();
@@ -623,7 +655,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
worklist.push_back(node);
LLVM_DEBUG({
- llvm::dbgs() << "Created root group node with " << node->ops.size()
+ llvm::dbgs() << "Created root group node with " << node->size()
<< " operations of type "
<< (group.isLoadGroup() ? "Load" : "Store") << "\n";
});
@@ -631,6 +663,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
OperationsFingerprint fingerprints(graph);
+ // Process node uses, going top-down.
auto processUse = [&](SLPGraphNode *node, OpOperand &use) {
Operation *user = use.getOwner();
auto *existingNode = graph.getNodeForOp(user);
@@ -680,6 +713,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
worklist.push_back(newNode);
};
+ // Process node operands, going bottom-up.
auto processOperands = [&](SLPGraphNode *node, Value operand, int64_t index) {
Operation *srcOp = operand.getDefiningOp();
if (!srcOp)
@@ -720,7 +754,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
while (!worklist.empty()) {
SLPGraphNode *node = worklist.pop_back_val();
- LLVM_DEBUG(llvm::dbgs() << "Processing node with " << node->ops.size()
+ LLVM_DEBUG(llvm::dbgs() << "Processing node with " << node->size()
<< " operations, first op: "
<< node->ops.front()->getName() << "\n");
>From 5613035ee349cd4d5e157dbb82eed380ccad3920 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 22:08:35 +0200
Subject: [PATCH 25/52] check arith.add indices
---
.../Vector/Transforms/SLPVectorizer.cpp | 13 ++++-
mlir/test/Dialect/Vector/slp-vectorize.mlir | 54 +++++++++++++++++++
2 files changed, 66 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 24059ec355b30..aa2f3108712f1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -145,7 +145,18 @@ static bool isAdjacentIndices(Value idx1, Value idx2) {
return *c1 + 1 == *c2;
}
- // TODO: Check arith.add, affine.apply, etc
+ if (auto addOp2 = idx2.getDefiningOp<arith::AddIOp>()) {
+ if (addOp2.getLhs() == idx1 && getConstantIntValue(addOp2.getRhs()) == 1)
+ return true;
+
+ if (auto addOp1 = idx1.getDefiningOp<arith::AddIOp>()) {
+ if (addOp1.getLhs() == addOp2.getLhs() &&
+ isAdjacentIndices(addOp1.getRhs(), addOp2.getRhs()))
+ return true;
+ }
+ }
+
+ // TODO: affine.apply, etc
return false;
}
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 2e9298d11ed05..edb722472995d 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -75,6 +75,60 @@ func.func @read_write_interleaved(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
}
+// CHECK-LABEL: func @read_write_add_index
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>, %[[ARG2:.*]]: index)
+func.func @read_write_add_index(%arg0: memref<8xi32>, %arg1: memref<8xi32>, %arg2: index) {
+ // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG2]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG2]]] : memref<8xi32>, vector<4xi32>
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %ind1 = arith.addi %arg2, %c1 : index
+ %ind2 = arith.addi %arg2, %c2 : index
+ %ind3 = arith.addi %arg2, %c3 : index
+
+ %0 = memref.load %arg0[%arg2] : memref<8xi32>
+ %1 = memref.load %arg0[%ind1] : memref<8xi32>
+ %2 = memref.load %arg0[%ind2] : memref<8xi32>
+ %3 = memref.load %arg0[%ind3] : memref<8xi32>
+
+ memref.store %0, %arg0[%arg2] : memref<8xi32>
+ memref.store %1, %arg0[%ind1] : memref<8xi32>
+ memref.store %2, %arg0[%ind2] : memref<8xi32>
+ memref.store %3, %arg0[%ind3] : memref<8xi32>
+
+ return
+}
+
+
+// CHECK-LABEL: func @read_write_add_index_interleaved
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>, %[[ARG2:.*]]: index)
+func.func @read_write_add_index_interleaved(%arg0: memref<8xi32>, %arg1: memref<8xi32>, %arg2: index) {
+ // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG2]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG2]]] : memref<8xi32>, vector<4xi32>
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %ind1 = arith.addi %arg2, %c1 : index
+ %ind2 = arith.addi %arg2, %c2 : index
+ %ind3 = arith.addi %arg2, %c3 : index
+
+ %0 = memref.load %arg0[%arg2] : memref<8xi32>
+ %1 = memref.load %arg0[%ind1] : memref<8xi32>
+ %3 = memref.load %arg0[%ind3] : memref<8xi32>
+ %2 = memref.load %arg0[%ind2] : memref<8xi32>
+
+ memref.store %3, %arg0[%ind3] : memref<8xi32>
+ memref.store %0, %arg0[%arg2] : memref<8xi32>
+ memref.store %1, %arg0[%ind1] : memref<8xi32>
+ memref.store %2, %arg0[%ind2] : memref<8xi32>
+
+ return
+}
+
+
// CHECK-LABEL: func @read_read_add
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i32, i32, i32){
>From a43bc4144f656a575cc703331ef34e6feac422a4 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 23:41:34 +0200
Subject: [PATCH 26/52] fix vector sizes
---
.../Vector/Transforms/SLPVectorizer.cpp | 24 ++++++----
mlir/test/Dialect/Vector/slp-vectorize.mlir | 47 +++++++++++++++++++
2 files changed, 61 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index aa2f3108712f1..dfd4747f615ee 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -12,14 +12,11 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
-#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/SHA1.h"
@@ -371,15 +368,24 @@ class SLPGraph {
}
});
- auto isGoodNode = [&](SLPGraphNode *node) {
+ auto isBadNode = [&](SLPGraphNode *node) {
return node->users.empty() && node->operands.empty();
};
- IRMapping mapping;
+ // Update vec sizes if inputs are smaller.
for (auto *node : sortedNodes) {
- if (isGoodNode(node))
- continue;
+ size_t size = node->size();
+ for (auto *operand : node->operands)
+ size = std::min(size, operand->size());
+
+ node->ops.resize(size);
+ }
+
+ // Remove nodes that are not good (have users or operands)
+ llvm::erase_if(sortedNodes, isBadNode);
+ IRMapping mapping;
+ for (auto *node : sortedNodes) {
int64_t numElements = node->size();
Operation *op = node->ops.front();
rewriter.setInsertionPoint(op);
@@ -462,14 +468,12 @@ class SLPGraph {
}
for (auto *node : llvm::reverse(sortedNodes)) {
- if (isGoodNode(node))
- continue;
-
for (Operation *op : node->ops) {
rewriter.eraseOp(op);
}
}
+ LLVM_DEBUG(llvm::dbgs() << "Vectorization completed successfully\n");
return success();
}
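A minimal sketch of the size-mismatch case this fix handles (assumed IR, not taken from the patch): when part of a 4-wide store group is fed by nodes that could only be grouped 3 wide (for example because a call interrupted the load group), node sizes are shrunk to the minimum over their operands and the remaining lane stays scalar:

    %v0 = vector.load %mem[%c0] : memref<8xi32>, vector<3xi32>
    %v1 = vector.load %mem2[%c0] : memref<8xi32>, vector<3xi32>
    %sum = arith.addi %v0, %v1 : vector<3xi32>
    vector.store %sum, %mem[%c0] : memref<8xi32>, vector<3xi32>
    memref.store %s3, %mem[%c3] : memref<8xi32>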
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index edb722472995d..7ad077d8fd78c 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -356,3 +356,50 @@ func.func @read_read_add_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>,
return
}
+
+
+func.func private @use(i32)
+
+// CHECK-LABEL: func @read_read_add_write_interleaved_use
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_interleaved_use(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C3:.*]] = arith.constant 3 : index
+ // CHECK: %[[V0:.*]] = memref.load %[[ARG0]][%[[C3]]] : memref<8xi32>
+ // CHECK: %[[V1:.*]] = memref.load %[[ARG1]][%[[C3]]] : memref<8xi32>
+ // CHECK: call @use(%[[V0]]) : (i32) -> ()
+ // CHECK: %[[V2:.*]] = arith.addi %[[V0]], %[[V1]] : i32
+ // CHECK: %[[V3:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<3xi32>
+ // CHECK: %[[V4:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<3xi32>
+ // CHECK: %[[V5:.*]] = arith.addi %[[V3]], %[[V4]] : vector<3xi32>
+ // CHECK: vector.store %[[V5]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<3xi32>
+ // CHECK: memref.store %[[V2]], %[[ARG0]][%[[C3]]] : memref<8xi32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %3 = memref.load %arg0[%c3] : memref<8xi32>
+ %7 = memref.load %arg1[%c3] : memref<8xi32>
+ call @use(%3) : (i32) -> ()
+ %11 = arith.addi %3, %7 : i32
+
+ %0 = memref.load %arg0[%c0] : memref<8xi32>
+ %4 = memref.load %arg1[%c0] : memref<8xi32>
+ %8 = arith.addi %0, %4 : i32
+
+ %2 = memref.load %arg0[%c2] : memref<8xi32>
+ %6 = memref.load %arg1[%c2] : memref<8xi32>
+ %10 = arith.addi %2, %6 : i32
+
+ %1 = memref.load %arg0[%c1] : memref<8xi32>
+ %5 = memref.load %arg1[%c1] : memref<8xi32>
+ %9 = arith.addi %1, %5 : i32
+
+ memref.store %8, %arg0[%c0] : memref<8xi32>
+ memref.store %11, %arg0[%c3] : memref<8xi32>
+ memref.store %10, %arg0[%c2] : memref<8xi32>
+ memref.store %9, %arg0[%c1] : memref<8xi32>
+
+ return
+}
>From dfd44a4b78b389d9703f49fbcfc835e308baeac6 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 18 May 2025 23:48:08 +0200
Subject: [PATCH 27/52] fix op insertion point
---
.../Vector/Transforms/SLPVectorizer.cpp | 12 ++++-
mlir/test/Dialect/Vector/slp-vectorize.mlir | 44 +++++++++++++++++++
2 files changed, 55 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index dfd4747f615ee..ab5bfd94de49c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -255,6 +255,16 @@ struct SLPGraphNode {
: ops(operations.begin(), operations.end()) {}
size_t size() const { return ops.size(); }
+
+ Operation *getEarliestOp() const {
+ assert(!ops.empty() && "empty node");
+ Operation *ret = ops.front();
+ for (Operation *op : ArrayRef(ops).drop_front()) {
+ if (op->isBeforeInBlock(ret))
+ ret = op;
+ }
+ return ret;
+ }
};
/// A graph of vectorizable operations
@@ -388,7 +398,7 @@ class SLPGraph {
for (auto *node : sortedNodes) {
int64_t numElements = node->size();
Operation *op = node->ops.front();
- rewriter.setInsertionPoint(op);
+ rewriter.setInsertionPoint(node->getEarliestOp());
Location loc = op->getLoc();
auto handleNonVectorInputs = [&](ValueRange operands) {
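An illustrative sketch of why the insertion point has to be the earliest op of the node rather than node->ops.front() (assumed IR, not taken from the patch); ops in a node are ordered by lane, so lane 0 can appear after lane 3 in the block:

    %l3 = memref.load %arg0[%c3] : memref<8xi32>
    %r  = arith.subi %l3, %l3 : i32
    %l0 = memref.load %arg0[%c0] : memref<8xi32>

Here node->ops.front() is %l0 (lane 0), but the vector.load replacing the group, and the vector.extract that will feed %r, must be created before %l3 so that they dominate the remaining scalar use.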
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 7ad077d8fd78c..9d06a1faa07b2 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -403,3 +403,47 @@ func.func @read_read_add_write_interleaved_use(%arg0: memref<8xi32>, %arg1: memr
return
}
+
+
+// CHECK-LABEL: func @read_read_add_write_interleaved_use_add
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_interleaved_use_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[V1:.*]] = vector.extract %[[V0]][3] : i32 from vector<4xi32>
+ // CHECK: %[[V2:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[V3:.*]] = vector.extract %[[V2]][3] : i32 from vector<4xi32>
+ // CHECK: %[[V4:.*]] = arith.subi %[[V1]], %[[V3]] : i32
+ // CHECK: %[[V5:.*]] = arith.addi %[[V0]], %[[V2]] : vector<4xi32>
+ // CHECK: vector.store %[[V5]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: call @use(%[[V4]]) : (i32) -> ()
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %3 = memref.load %arg0[%c3] : memref<8xi32>
+ %7 = memref.load %arg1[%c3] : memref<8xi32>
+ %12 = arith.subi %3, %7 : i32
+ %11 = arith.addi %3, %7 : i32
+
+ %0 = memref.load %arg0[%c0] : memref<8xi32>
+ %4 = memref.load %arg1[%c0] : memref<8xi32>
+ %8 = arith.addi %0, %4 : i32
+
+ %2 = memref.load %arg0[%c2] : memref<8xi32>
+ %6 = memref.load %arg1[%c2] : memref<8xi32>
+ %10 = arith.addi %2, %6 : i32
+
+ %1 = memref.load %arg0[%c1] : memref<8xi32>
+ %5 = memref.load %arg1[%c1] : memref<8xi32>
+ %9 = arith.addi %1, %5 : i32
+
+ memref.store %8, %arg0[%c0] : memref<8xi32>
+ memref.store %11, %arg0[%c3] : memref<8xi32>
+ memref.store %10, %arg0[%c2] : memref<8xi32>
+ memref.store %9, %arg0[%c1] : memref<8xi32>
+
+ call @use(%12) : (i32) -> ()
+ return
+}
>From 82da589254683c697e461f12c0e486bbcb6e9de1 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 19 May 2025 11:00:11 +0200
Subject: [PATCH 28/52] check same block
---
mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp | 3 +++
1 file changed, 3 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index ab5bfd94de49c..ec1c41dbd7b69 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -659,6 +659,9 @@ static bool isEquivalent(Operation *op1, Operation *op2) {
if (op1->getRawDictionaryAttrs() != op2->getRawDictionaryAttrs())
return false;
+ if (op1->getBlock() != op2->getBlock())
+ return false;
+
return true;
}
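A minimal sketch of what the new same-block check rejects (assumed IR, not taken from the patch): two otherwise identical adds that live in different blocks are no longer considered equivalent, so they are not merged into one node:

    %x = arith.addi %a, %b : i32
    cf.br ^bb1
  ^bb1:
    %y = arith.addi %a, %b : i32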
>From 5c339976bb985449473a30f7783b1dbc4741efd6 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 12:13:50 +0200
Subject: [PATCH 29/52] cleanup and comments
---
.../Vector/Transforms/SLPVectorizer.cpp | 118 +++++++++++++-----
1 file changed, 90 insertions(+), 28 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index ec1c41dbd7b69..c6e20961725c7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -48,7 +48,7 @@ struct MemoryOpGroup {
size_t size() const { return ops.size(); }
};
-static bool isReadOp(Operation *op) {
+static bool maybeReadOp(Operation *op) {
auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
if (!effectInterface)
return true;
@@ -56,7 +56,7 @@ static bool isReadOp(Operation *op) {
return effectInterface.hasEffect<MemoryEffects::Read>();
}
-static bool isWriteOp(Operation *op) {
+static bool maybeWriteOp(Operation *op) {
auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
if (!effectInterface)
return true;
@@ -72,10 +72,11 @@ static SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block) {
MemoryOpGroup *currentGroup = nullptr;
for (Operation &op : block) {
+ // Check if current group is interrupted by a read or write op.
if (currentGroup) {
- if (currentGroup->isLoadGroup() && isWriteOp(&op)) {
+ if (currentGroup->isLoadGroup() && maybeWriteOp(&op)) {
currentGroup = nullptr;
- } else if (currentGroup->isStoreGroup() && isReadOp(&op)) {
+ } else if (currentGroup->isStoreGroup() && maybeReadOp(&op)) {
currentGroup = nullptr;
}
}
@@ -83,7 +84,7 @@ static SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block) {
if (!isa<memref::LoadOp, memref::StoreOp>(op))
continue;
- bool isLoad = isReadOp(&op);
+ bool isLoad = maybeReadOp(&op);
MemoryOpGroup::Type type =
isLoad ? MemoryOpGroup::Type::Load : MemoryOpGroup::Type::Store;
@@ -99,6 +100,7 @@ static SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block) {
}
static Value getBase(Operation *op) {
+ assert(op && "null op");
if (auto loadOp = dyn_cast<memref::LoadOp>(op))
return loadOp.getMemRef();
if (auto storeOp = dyn_cast<memref::StoreOp>(op))
@@ -120,6 +122,7 @@ static bool isContiguousLastDim(Value val) {
}
static ValueRange getIndices(Operation *op) {
+ assert(op && "null op");
if (auto loadOp = dyn_cast<memref::LoadOp>(op))
return loadOp.getIndices();
if (auto storeOp = dyn_cast<memref::StoreOp>(op))
@@ -128,6 +131,7 @@ static ValueRange getIndices(Operation *op) {
}
static Type getElementType(Operation *op) {
+ assert(op && "null op");
if (auto loadOp = dyn_cast<memref::LoadOp>(op))
return loadOp.getResult().getType();
if (auto storeOp = dyn_cast<memref::StoreOp>(op))
@@ -135,7 +139,7 @@ static Type getElementType(Operation *op) {
return {};
}
-/// Check if two indices are consecutive, i.e fastest index differs by 1.
+/// Check if two indices are consecutive, i.e. index1 + 1 == index2.
static bool isAdjacentIndices(Value idx1, Value idx2) {
if (auto c1 = getConstantIntValue(idx1)) {
if (auto c2 = getConstantIntValue(idx2))
@@ -153,7 +157,7 @@ static bool isAdjacentIndices(Value idx1, Value idx2) {
}
}
- // TODO: affine.apply, etc
+ // TODO: Handle affine.apply, etc
return false;
}
@@ -173,6 +177,9 @@ static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2) {
/// This is done by checking if the base memrefs are the same, the last
/// dimension is contiguous, and the element types and indices are compatible
static bool isAdjacentOps(Operation *op1, Operation *op2) {
+ assert(op1 && "null op1");
+ assert(op2 && "null op2");
+
Value base1 = getBase(op1);
Value base2 = getBase(op2);
if (base1 != base2)
@@ -181,8 +188,10 @@ static bool isAdjacentOps(Operation *op1, Operation *op2) {
if (!isContiguousLastDim(base1))
return false;
- return getElementType(op1) == getElementType(op2) &&
- isAdjacentIndices(getIndices(op1), getIndices(op2));
+ if (getElementType(op1) != getElementType(op2))
+ return false;
+
+ return isAdjacentIndices(getIndices(op1), getIndices(op2));
}
// Extract contiguous groups from a MemoryOpGroup
@@ -229,6 +238,7 @@ extractContiguousGroups(const MemoryOpGroup &group) {
} while (foundMore);
if (currentOps.size() <= 1) {
+ // Do not vectorize if there is only one op.
result.pop_back();
continue;
}
@@ -256,9 +266,16 @@ struct SLPGraphNode {
size_t size() const { return ops.size(); }
- Operation *getEarliestOp() const {
+ Operation *op() const {
+ assert(!ops.empty() && "empty ops");
+ return ops.front();
+ }
+
+ Operation *getInsertionPoint() const {
+ // Find the topologically first op, which is not necessarily the first in
+ // `ops`, as `ops` are sorted by their position in the vector.
assert(!ops.empty() && "empty node");
- Operation *ret = ops.front();
+ Operation *ret = op();
for (Operation *op : ArrayRef(ops).drop_front()) {
if (op->isBeforeInBlock(ret))
ret = op;
@@ -374,15 +391,20 @@ class SLPGraph {
llvm::dbgs() << "Topologically sorted nodes:\n";
for (auto *node : sortedNodes) {
llvm::dbgs() << " Node with " << node->size()
- << " operations: " << node->ops.front()->getName() << "\n";
+ << " operations: " << node->op()->getName() << "\n";
}
});
auto isBadNode = [&](SLPGraphNode *node) {
+ // Do not vectorize stray nodes which are not connected to any other
+ // nodes.
return node->users.empty() && node->operands.empty();
};
- // Update vec sizes if inputs are smaller.
+ // Update a node's vec size if its inputs' vec sizes are smaller.
+ // This is needed to handle situations where we have 3->3->4 sizes in the tree.
+ // TODO: It may be possible to reconstruct the larger vec size by combining
+ // the smaller source vector with the scalar arg.
for (auto *node : sortedNodes) {
size_t size = node->size();
for (auto *operand : node->operands)
@@ -391,14 +413,19 @@ class SLPGraph {
node->ops.resize(size);
}
- // Remove nodes that are not good (have users or operands)
llvm::erase_if(sortedNodes, isBadNode);
IRMapping mapping;
for (auto *node : sortedNodes) {
+ // `op` is the op with the smallest index in the vector, not necessarily
+ // a good insertion point.
+ Operation *op = node->op();
+ Operation *ip = node->getInsertionPoint();
+ if (!ip)
+ return op->emitError("no insertion point found for node");
+
+ rewriter.setInsertionPoint(ip);
int64_t numElements = node->size();
- Operation *op = node->ops.front();
- rewriter.setInsertionPoint(node->getEarliestOp());
Location loc = op->getLoc();
auto handleNonVectorInputs = [&](ValueRange operands) {
@@ -477,6 +504,10 @@ class SLPGraph {
}
}
+ LLVM_DEBUG(llvm::dbgs() << "Erasing original ops\n");
+
+ // As all nodes were cloned, we need to erase the original ops in reverse
+ // topo order to avoid invalidating users.
for (auto *node : llvm::reverse(sortedNodes)) {
for (Operation *op : node->ops) {
rewriter.eraseOp(op);
@@ -498,8 +529,7 @@ class SLPGraph {
if (!node->isRoot)
continue;
llvm::dbgs() << " "
- << (isa<memref::LoadOp>(node->ops.front()) ? "LOAD"
- : "STORE")
+ << (isa<memref::LoadOp>(node->op()) ? "LOAD" : "STORE")
<< " group with " << node->size() << " operations:\n";
for (auto *op : node->ops) {
llvm::dbgs() << " " << *op << "\n";
@@ -588,10 +618,41 @@ static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
/// multiple users with the same immediate op type, this class tries to compute
/// fingerprint for such ops based on the entire ops graph to maximize further
/// scalar ops merging.
+///
+/// Example:
+/// ```
+/// %0 = memref.load %arg0[%c0] : memref<8xi32>
+/// %1 = memref.load %arg0[%c1] : memref<8xi32>
+/// %2 = memref.load %arg0[%c2] : memref<8xi32>
+/// %3 = memref.load %arg0[%c3] : memref<8xi32>
+///
+/// %4 = memref.load %arg1[%c0] : memref<8xi32>
+/// %5 = memref.load %arg1[%c1] : memref<8xi32>
+/// %6 = memref.load %arg1[%c2] : memref<8xi32>
+/// %7 = memref.load %arg1[%c3] : memref<8xi32>
+///
+/// %8 = arith.addi %0, %4 : i32
+/// %12 = arith.addi %0, %arg2 : i32
+///
+/// %13 = arith.addi %1, %arg3 : i32
+/// %9 = arith.addi %1, %5 : i32
+///
+/// %10 = arith.addi %2, %6 : i32
+/// %14 = arith.addi %2, %arg4 : i32
+///
+/// %15 = arith.addi %3, %arg5 : i32
+/// %11 = arith.addi %3, %7 : i32
+/// ```
+/// Here each load has multiple uses, in a different order, and we want to
+/// merge them in a way that maximizes the number of merged ops.
+///
+/// To achieve this, we compute a fingerprint for each op that includes its
+/// other operands, which in this example are the other loads.
struct OperationsFingerprint {
OperationsFingerprint(const SLPGraph &graph) : graph(graph) {}
Fingerprint getFingerprint(Operation *op) {
+ assert(op && "null op");
auto it = fingerprints.find(op);
if (it != fingerprints.end())
return it->second;
@@ -653,6 +714,9 @@ struct OperationsFingerprint {
};
static bool isEquivalent(Operation *op1, Operation *op2) {
+ assert(op1 && "null op1");
+ assert(op2 && "null op2");
+
if (op1->getName() != op2->getName())
return false;
@@ -696,9 +760,8 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
Operation *user = use.getOwner();
auto *existingNode = graph.getNodeForOp(user);
if (existingNode) {
- LLVM_DEBUG(llvm::dbgs()
- << " Adding edge from " << node->ops.front()->getName()
- << " to " << user->getName() << "\n");
+ LLVM_DEBUG(llvm::dbgs() << " Adding edge from " << node->op()->getName()
+ << " to " << user->getName() << "\n");
graph.addEdge(node, existingNode);
return;
}
@@ -749,9 +812,8 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
auto *existingNode = graph.getNodeForOp(srcOp);
if (existingNode) {
- LLVM_DEBUG(llvm::dbgs()
- << " Adding edge from " << srcOp->getName() << " to "
- << node->ops.front()->getName() << "\n");
+ LLVM_DEBUG(llvm::dbgs() << " Adding edge from " << srcOp->getName()
+ << " to " << node->op()->getName() << "\n");
graph.addEdge(existingNode, node);
return;
}
@@ -782,11 +844,11 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
while (!worklist.empty()) {
SLPGraphNode *node = worklist.pop_back_val();
- LLVM_DEBUG(llvm::dbgs() << "Processing node with " << node->size()
- << " operations, first op: "
- << node->ops.front()->getName() << "\n");
+ LLVM_DEBUG(llvm::dbgs()
+ << "Processing node with " << node->size()
+ << " operations, first op: " << node->op()->getName() << "\n");
- Operation *op = node->ops.front();
+ Operation *op = node->op();
for (OpOperand &use : op->getUses())
processUse(node, use);
>From d30060c54c75226c9bd9002b0f264bbbcd70d8b0 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 12:28:49 +0200
Subject: [PATCH 30/52] test
---
mlir/test/Dialect/Vector/slp-vectorize.mlir | 58 ++++++++++++++++++++-
1 file changed, 57 insertions(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 9d06a1faa07b2..f744098324243 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -131,7 +131,7 @@ func.func @read_write_add_index_interleaved(%arg0: memref<8xi32>, %arg1: memref<
// CHECK-LABEL: func @read_read_add
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
-func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i32, i32, i32){
+func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i32, i32, i32) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
// CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
@@ -309,6 +309,8 @@ func.func @read_read_add_write_interleaved(%arg0: memref<8xi32>, %arg1: memref<8
// CHECK-SAME: , %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
func.func @read_read_add_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>,
%arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) {
+ // Loads from the first group each have 2 uses (in potentially different
+ // order); make sure both add groups were vectorized.
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
// CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
@@ -357,6 +359,60 @@ func.func @read_read_add_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>,
return
}
+// CHECK-LABEL: func @read_read_add_add
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>
+// CHECK-SAME: , %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
+func.func @read_read_add_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>,
+ %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) ->
+ (i32, i32, i32, i32, i32, i32, i32, i32){
+ // Loads from the first group each have 2 uses (in potentially different
+ // order); make sure both add groups were vectorized.
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[ADD1:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+ // CHECK: %[[R0:.*]] = vector.extract %[[ADD1]][0] : i32 from vector<4xi32>
+ // CHECK: %[[R1:.*]] = vector.extract %[[ADD1]][1] : i32 from vector<4xi32>
+ // CHECK: %[[R2:.*]] = vector.extract %[[ADD1]][2] : i32 from vector<4xi32>
+ // CHECK: %[[R3:.*]] = vector.extract %[[ADD1]][3] : i32 from vector<4xi32>
+ // CHECK: %[[C:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]] : vector<4xi32>
+ // CHECK: %[[ADD2:.*]] = arith.addi %[[A]], %[[C]] : vector<4xi32>
+ // CHECK: %[[R4:.*]] = vector.extract %[[ADD2]][0] : i32 from vector<4xi32>
+ // CHECK: %[[R5:.*]] = vector.extract %[[ADD2]][1] : i32 from vector<4xi32>
+ // CHECK: %[[R6:.*]] = vector.extract %[[ADD2]][2] : i32 from vector<4xi32>
+ // CHECK: %[[R7:.*]] = vector.extract %[[ADD2]][3] : i32 from vector<4xi32>
+ // CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]] : i32, i32, i32, i32, i32, i32, i32, i32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %0 = memref.load %arg0[%c0] : memref<8xi32>
+ %1 = memref.load %arg0[%c1] : memref<8xi32>
+ %2 = memref.load %arg0[%c2] : memref<8xi32>
+ %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+ %4 = memref.load %arg1[%c0] : memref<8xi32>
+ %5 = memref.load %arg1[%c1] : memref<8xi32>
+ %6 = memref.load %arg1[%c2] : memref<8xi32>
+ %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+ %8 = arith.addi %0, %4 : i32
+ %12 = arith.addi %0, %arg2 : i32
+
+ %13 = arith.addi %1, %arg3 : i32
+ %9 = arith.addi %1, %5 : i32
+
+ %10 = arith.addi %2, %6 : i32
+ %14 = arith.addi %2, %arg4 : i32
+
+ %15 = arith.addi %3, %arg5 : i32
+ %11 = arith.addi %3, %7 : i32
+
+ return %8, %9, %10, %11, %12, %13, %14, %15 : i32, i32, i32, i32, i32, i32, i32, i32
+}
+
+
func.func private @use(i32)
>From a4b9529fbfb460c71465fff58d19710c5f23f39b Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 12:38:22 +0200
Subject: [PATCH 31/52] test
---
mlir/test/Dialect/Vector/slp-vectorize.mlir | 14 ++++++++++++++
1 file changed, 14 insertions(+)
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index f744098324243..293d004879fe5 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -503,3 +503,17 @@ func.func @read_read_add_write_interleaved_use_add(%arg0: memref<8xi32>, %arg1:
call @use(%12) : (i32) -> ()
return
}
+
+
+// CHECK-LABEL: func @negative_single_op
+func.func @negative_single_op(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+ // CHECK-NOT: vector
+ %c0 = arith.constant 0 : index
+
+ %0 = memref.load %arg0[%c0] : memref<8xi32>
+ %4 = memref.load %arg1[%c0] : memref<8xi32>
+ %8 = arith.addi %0, %4 : i32
+ memref.store %8, %arg0[%c0] : memref<8xi32>
+
+ return
+}
>From 78a5ed97251236873fe53b5613bc8872e65caece Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 13:45:19 +0200
Subject: [PATCH 32/52] Run until fixed point
---
.../Vector/Transforms/SLPVectorizer.cpp | 101 +++++++++++-------
mlir/test/Dialect/Vector/slp-vectorize.mlir | 49 +++++++++
2 files changed, 109 insertions(+), 41 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index c6e20961725c7..6059f8937e000 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -372,10 +372,11 @@ class SLPGraph {
return result;
}
- /// Vectorize the operations in the graph
- LogicalResult vectorize(IRRewriter &rewriter) {
+ /// Vectorize the operations in the graph.
+ /// Returns number of nodes vectorized or failure if failed.
+ FailureOr<size_t> vectorize(IRRewriter &rewriter) {
if (nodes.empty())
- return success();
+ return 0;
LLVM_DEBUG(llvm::dbgs()
<< "Vectorizing SLP graph with " << nodes.size() << " nodes\n");
@@ -515,7 +516,7 @@ class SLPGraph {
}
LLVM_DEBUG(llvm::dbgs() << "Vectorization completed successfully\n");
- return success();
+ return sortedNodes.size();
}
/// Print the graph structure
@@ -720,7 +721,7 @@ static bool isEquivalent(Operation *op1, Operation *op2) {
if (op1->getName() != op2->getName())
return false;
- if (op1->getRawDictionaryAttrs() != op2->getRawDictionaryAttrs())
+ if (op1->getAttrs() != op2->getAttrs())
return false;
if (op1->getBlock() != op2->getBlock())
@@ -859,48 +860,66 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
return graph;
}
+/// Try to vectorize ops in a block.
+/// Returns the number of nodes vectorized, or failure if vectorization failed.
+static FailureOr<size_t> tryToVectorizeInBlock(Block &block) {
+ LLVM_DEBUG(llvm::dbgs() << "Processing block in operation: "
+ << block.getParentOp()->getName() << "\n");
+
+ // Collect memory operation groups
+ SmallVector<MemoryOpGroup> groups = collectMemoryOpGroups(block);
+
+ // Process each group to find contiguous sequences
+ SmallVector<MemoryOpGroup> rootGroups;
+ for (const auto &group : groups) {
+ SmallVector<MemoryOpGroup> contiguousGroups =
+ extractContiguousGroups(group);
+ LLVM_DEBUG({
+ llvm::dbgs() << "Found " << contiguousGroups.size()
+ << " contiguous groups in "
+ << (group.isLoadGroup() ? "load" : "store") << " group\n";
+ for (const auto &contigGroup : contiguousGroups) {
+ llvm::dbgs() << " Contiguous group with " << contigGroup.size()
+ << " operations\n";
+ }
+ });
+ rootGroups.append(contiguousGroups.begin(), contiguousGroups.end());
+ }
+
+ // Build the SLP graph from root groups
+ SLPGraph graph = buildSLPGraph(rootGroups);
+ LLVM_DEBUG(graph.print());
+
+ // Vectorize the graph
+ IRRewriter rewriter(block.getParentOp()->getContext());
+ FailureOr<size_t> numNodesVectorized = graph.vectorize(rewriter);
+ if (failed(numNodesVectorized))
+ LLVM_DEBUG(llvm::dbgs() << "Failed to vectorize graph\n");
+
+ return numNodesVectorized;
+}
+
void GreedySLPVectorizerPass::runOnOperation() {
Operation *op = getOperation();
- // Walk all blocks recursively
- op->walk([&](Block *block) -> WalkResult {
- LLVM_DEBUG(llvm::dbgs() << "Processing block in operation: "
- << block->getParentOp()->getName() << "\n");
-
- // Collect memory operation groups
- SmallVector<MemoryOpGroup> groups = collectMemoryOpGroups(*block);
-
- // Process each group to find contiguous sequences
- SmallVector<MemoryOpGroup> rootGroups;
- for (const auto &group : groups) {
- SmallVector<MemoryOpGroup> contiguousGroups =
- extractContiguousGroups(group);
- LLVM_DEBUG({
- llvm::dbgs() << "Found " << contiguousGroups.size()
- << " contiguous groups in "
- << (group.isLoadGroup() ? "load" : "store") << " group\n";
- for (const auto &contigGroup : contiguousGroups) {
- llvm::dbgs() << " Contiguous group with " << contigGroup.size()
- << " operations\n";
- }
- });
- rootGroups.append(contiguousGroups.begin(), contiguousGroups.end());
- }
+ // Run until fixed point is reached.
+ bool changed;
+ do {
+ changed = false;
+ // Walk all blocks recursively
+ if (op->walk([&](Block *block) -> WalkResult {
+ FailureOr<size_t> numNodesVectorized =
+ tryToVectorizeInBlock(*block);
+ if (failed(numNodesVectorized))
+ return WalkResult::interrupt();
- // Build the SLP graph from root groups
- SLPGraph graph = buildSLPGraph(rootGroups);
- LLVM_DEBUG(graph.print());
+ changed = changed || *numNodesVectorized > 0;
- // Vectorize the graph
- IRRewriter rewriter(&getContext());
- if (failed(graph.vectorize(rewriter))) {
- LLVM_DEBUG(llvm::dbgs() << "Failed to vectorize graph\n");
- signalPassFailure();
- return WalkResult::interrupt();
- }
+ return WalkResult::advance();
+ }).wasInterrupted())
+ return signalPassFailure();
- return WalkResult::advance();
- });
+ } while (changed);
}
} // namespace
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 293d004879fe5..517b2318f773d 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -266,6 +266,55 @@ func.func @read_read_add_write_size_mismatch(%arg0: memref<8xi32>, %arg1: memref
}
+// CHECK-LABEL: func @read_read_add_write_attrs_mismatch
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_attrs_mismatch(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[V1:.*]] = vector.extract %[[V0]][2] : i32 from vector<4xi32>
+ // CHECK: %[[V2:.*]] = vector.extract %[[V0]][3] : i32 from vector<4xi32>
+ // CHECK: %[[V3:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[V4:.*]] = vector.extract %[[V3]][2] : i32 from vector<4xi32>
+ // CHECK: %[[V5:.*]] = vector.extract %[[V3]][3] : i32 from vector<4xi32>
+ // CHECK: %[[V6:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+ // CHECK: %[[V7:.*]] = vector.extract_strided_slice %[[V3]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+ // CHECK: %[[V8:.*]] = arith.addi %[[V6]], %[[V7]] overflow<nsw> : vector<2xi32>
+ // CHECK: %[[V9:.*]] = vector.from_elements %[[V1]], %[[V2]] : vector<2xi32>
+ // CHECK: %[[V10:.*]] = vector.from_elements %[[V4]], %[[V5]] : vector<2xi32>
+ // CHECK: %[[V11:.*]] = arith.addi %[[V9]], %[[V10]] overflow<nuw> : vector<2xi32>
+ // CHECK: vector.store %[[V8]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+ // CHECK: vector.store %[[V11]], %[[ARG0]][%[[C2]]] : memref<8xi32>, vector<2xi32>
+
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %0 = memref.load %arg0[%c0] : memref<8xi32>
+ %1 = memref.load %arg0[%c1] : memref<8xi32>
+ %2 = memref.load %arg0[%c2] : memref<8xi32>
+ %3 = memref.load %arg0[%c3] : memref<8xi32>
+
+ %4 = memref.load %arg1[%c0] : memref<8xi32>
+ %5 = memref.load %arg1[%c1] : memref<8xi32>
+ %6 = memref.load %arg1[%c2] : memref<8xi32>
+ %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+ %8 = arith.addi %0, %4 overflow<nsw> : i32
+ %9 = arith.addi %1, %5 overflow<nsw> : i32
+ %10 = arith.addi %2, %6 overflow<nuw> : i32
+ %11 = arith.addi %3, %7 overflow<nuw> : i32
+
+ memref.store %8, %arg0[%c0] : memref<8xi32>
+ memref.store %9, %arg0[%c1] : memref<8xi32>
+ memref.store %10, %arg0[%c2] : memref<8xi32>
+ memref.store %11, %arg0[%c3] : memref<8xi32>
+
+ return
+}
+
+
// CHECK-LABEL: func @read_read_add_write_interleaved
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
func.func @read_read_add_write_interleaved(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
>From b875a1867746acd5bb0d941171c2908be7d3d52e Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 14:16:40 +0200
Subject: [PATCH 33/52] run DCE between iterations
---
.../Vector/Transforms/SLPVectorizer.cpp | 24 +++++++++++--------
1 file changed, 14 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 6059f8937e000..871611a891351 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/SHA1.h"
@@ -906,19 +907,22 @@ void GreedySLPVectorizerPass::runOnOperation() {
bool changed;
do {
changed = false;
- // Walk all blocks recursively
- if (op->walk([&](Block *block) -> WalkResult {
- FailureOr<size_t> numNodesVectorized =
- tryToVectorizeInBlock(*block);
- if (failed(numNodesVectorized))
- return WalkResult::interrupt();
-
- changed = changed || *numNodesVectorized > 0;
+ auto visitor = [&](Block *block) -> WalkResult {
+ FailureOr<size_t> numNodesVectorized = tryToVectorizeInBlock(*block);
+ if (failed(numNodesVectorized))
+ return WalkResult::interrupt();
- return WalkResult::advance();
- }).wasInterrupted())
+ changed = changed || *numNodesVectorized > 0;
+ return WalkResult::advance();
+ };
+ // Walk all blocks recursively
+ if (op->walk(visitor).wasInterrupted())
return signalPassFailure();
+ // Run empty `applyPatternsGreedily` for simple DCE and folding.
+ if (changed)
+ (void)applyPatternsGreedily(
+ op, {}, GreedyRewriteConfig().enableFolding().enableConstantCSE());
} while (changed);
}
>From 6746318fb79cebe27611e1fd6b566c9562f2e353 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 15:37:57 +0200
Subject: [PATCH 34/52] comment
---
mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 871611a891351..ce6b088e0e07f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -715,9 +715,13 @@ struct OperationsFingerprint {
DenseMap<Operation *, Fingerprint> fingerprints;
};
+/// Check if two ops are equivalent for the purposes of SLP vectorization, i.e.
+/// they can be merged into a single vector op.
static bool isEquivalent(Operation *op1, Operation *op2) {
assert(op1 && "null op1");
assert(op2 && "null op2");
+ if (op1 == op2)
+ return true;
if (op1->getName() != op2->getName())
return false;
>From e650bbe4797b92058b7172d64b3a47d0153efc1f Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 15:45:52 +0200
Subject: [PATCH 35/52] test
---
mlir/test/Dialect/Vector/slp-vectorize.mlir | 75 ++++++++++++++++++++-
1 file changed, 72 insertions(+), 3 deletions(-)
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 517b2318f773d..cbcd553d90f0a 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -1,6 +1,20 @@
// RUN: mlir-opt %s --greedy-slp-vectorizer | FileCheck %s
+// CHECK-LABEL: func @negative_single_op
+func.func @negative_single_op(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+ // CHECK-NOT: vector
+ %c0 = arith.constant 0 : index
+
+ %0 = memref.load %arg0[%c0] : memref<8xi32>
+ %4 = memref.load %arg1[%c0] : memref<8xi32>
+ %8 = arith.addi %0, %4 : i32
+ memref.store %8, %arg0[%c0] : memref<8xi32>
+
+ return
+}
+
+
// CHECK-LABEL: func @read_write
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
func.func @read_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
@@ -554,15 +568,70 @@ func.func @read_read_add_write_interleaved_use_add(%arg0: memref<8xi32>, %arg1:
}
-// CHECK-LABEL: func @negative_single_op
-func.func @negative_single_op(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
- // CHECK-NOT: vector
+// CHECK-LABEL: func @negative_different_blocks
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @negative_different_blocks(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+ // CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[V1:.*]] = vector.extract %[[V0]][2] : i32 from vector<4xi32>
+ // CHECK: %[[V2:.*]] = vector.extract %[[V0]][3] : i32 from vector<4xi32>
+ // CHECK: %[[V3:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[V4:.*]] = vector.extract %[[V3]][2] : i32 from vector<4xi32>
+ // CHECK: %[[V5:.*]] = vector.extract %[[V3]][3] : i32 from vector<4xi32>
+ // CHECK: cf.br ^bb1
+ // CHECK: ^bb1:
+ // CHECK: %[[V6:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+ // CHECK: %[[V7:.*]] = vector.extract_strided_slice %[[V3]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+ // CHECK: %[[V8:.*]] = arith.addi %[[V6]], %[[V7]] : vector<2xi32>
+ // CHECK: %[[V9:.*]] = vector.extract %[[V8]][0] : i32 from vector<2xi32>
+ // CHECK: %[[V10:.*]] = vector.extract %[[V8]][1] : i32 from vector<2xi32>
+ // CHECK: cf.br ^bb2
+ // CHECK: ^bb2:
+ // TODO: We need to properly handle vector.extract args to vectorize these.
+ // CHECK: %[[V11:.*]] = arith.addi %[[V1]], %[[V4]] : i32
+ // CHECK: %[[V12:.*]] = arith.addi %[[V2]], %[[V5]] : i32
+ // CHECK: cf.br ^bb3
+ // CHECK: ^bb3:
+ // CHECK: memref.store %[[V9]], %[[ARG0]][%[[C0]]] : memref<8xi32>
+ // CHECK: memref.store %[[V10]], %[[ARG0]][%[[C1]]] : memref<8xi32>
+ // CHECK: memref.store %[[V11]], %[[ARG0]][%[[C2]]] : memref<8xi32>
+ // CHECK: memref.store %[[V12]], %[[ARG0]][%[[C3]]] : memref<8xi32>
+
%c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
%0 = memref.load %arg0[%c0] : memref<8xi32>
+ %1 = memref.load %arg0[%c1] : memref<8xi32>
+ %2 = memref.load %arg0[%c2] : memref<8xi32>
+ %3 = memref.load %arg0[%c3] : memref<8xi32>
+
%4 = memref.load %arg1[%c0] : memref<8xi32>
+ %5 = memref.load %arg1[%c1] : memref<8xi32>
+ %6 = memref.load %arg1[%c2] : memref<8xi32>
+ %7 = memref.load %arg1[%c3] : memref<8xi32>
+
+ cf.br ^bb0
+
+^bb0:
%8 = arith.addi %0, %4 : i32
+ %9 = arith.addi %1, %5 : i32
+ cf.br ^bb1
+
+^bb1:
+ %10 = arith.addi %2, %6 : i32
+ %11 = arith.addi %3, %7 : i32
+ cf.br ^bb2
+
+^bb2:
memref.store %8, %arg0[%c0] : memref<8xi32>
+ memref.store %9, %arg0[%c1] : memref<8xi32>
+ memref.store %10, %arg0[%c2] : memref<8xi32>
+ memref.store %11, %arg0[%c3] : memref<8xi32>
return
}
>From 08f362e04c81cf8e64ab6379d5bc8595e379a465 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 15:47:12 +0200
Subject: [PATCH 36/52] cleanup
---
mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index ce6b088e0e07f..d73f35cd2599a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -924,9 +924,10 @@ void GreedySLPVectorizerPass::runOnOperation() {
return signalPassFailure();
// Run empty `applyPatternsGreedily` for simple DCE and folding.
- if (changed)
- (void)applyPatternsGreedily(
- op, {}, GreedyRewriteConfig().enableFolding().enableConstantCSE());
+ if (changed) {
+ auto config = GreedyRewriteConfig().enableFolding().enableConstantCSE();
+ (void)applyPatternsGreedily(op, {}, config);
+ }
} while (changed);
}
>From 4b44d61b6be1af15b1d6d97975a94f6b38600f79 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 17:50:52 +0200
Subject: [PATCH 37/52] process extract ops
---
.../Vector/Transforms/SLPVectorizer.cpp | 79 ++++++++++++++++---
mlir/test/Dialect/Vector/slp-vectorize.mlir | 74 +++++++----------
2 files changed, 99 insertions(+), 54 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index d73f35cd2599a..bdfaa72a36914 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -251,7 +251,19 @@ extractContiguousGroups(const MemoryOpGroup &group) {
}
static bool isVectorizable(Operation *op) {
- return OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1;
+ if (!OpTrait::hasElementwiseMappableTraits(op))
+ return false;
+
+ if (op->getNumResults() != 1)
+ return false;
+
+ for (auto type :
+ llvm::concat<Type>(op->getResultTypes(), op->getOperandTypes())) {
+ if (!type.isIntOrIndexOrFloat())
+ return false;
+ }
+
+ return true;
}
/// A node in the SLP graph representing a group of vectorizable operations
@@ -419,6 +431,12 @@ class SLPGraph {
IRMapping mapping;
for (auto *node : sortedNodes) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "Processing node with " << node->size()
+ << " operations\n";
+ llvm::dbgs() << " First op: " << *node->op() << "\n";
+ });
+
// `op` is the op with the smallest index in the vector, not necessarily
// a good insertion point.
Operation *op = node->op();
@@ -500,6 +518,9 @@ class SLPGraph {
mapping.map(op->getResults(), newOp->getResults());
handleNonVectorOutputs(newOp->getResult(0));
+ } else if (auto extract = dyn_cast<vector::ExtractOp>(op)) {
+ Value val = handleVecSizeMismatch(extract.getVector());
+ mapping.map(extract.getResult(), val);
} else {
op->emitError("unsupported operation");
return failure();
@@ -735,6 +756,14 @@ static bool isEquivalent(Operation *op1, Operation *op2) {
return true;
}
+/// Get the static position of the extract op, if it is 1-D and the position is static.
+static std::optional<int64_t> getExtractIndex(vector::ExtractOp extractOp) {
+ if (extractOp.getNumIndices() != 1 || extractOp.hasDynamicPosition())
+ return std::nullopt;
+
+ return extractOp.getStaticPosition().front();
+}
+
/// Build the SLP graph starting from memory operation groups and going both
/// top-down and bottom-up through uses/operands respectively.
static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
@@ -824,17 +853,47 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
return;
}
- if (!isVectorizable(srcOp))
- return;
-
SmallVector<Operation *> currentOps;
- currentOps.emplace_back(srcOp);
- for (Operation *op : ArrayRef(node->ops).drop_front()) {
- Operation *otherOp = op->getOperand(index).getDefiningOp();
- if (!otherOp || !isEquivalent(otherOp, srcOp))
- break;
+ if (auto extractOp = dyn_cast<vector::ExtractOp>(srcOp)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << " Processing vector.extract op with index "
+ << getExtractIndex(extractOp).value_or(-1) << "\n");
+ currentOps.push_back(extractOp);
+
+ std::optional<int64_t> extractIndex = getExtractIndex(extractOp);
+ if (!extractIndex)
+ return;
+
+ Value vector = extractOp.getVector();
+ int64_t currentIndex = *extractIndex;
+ for (Operation *op : ArrayRef(node->ops).drop_front()) {
+ auto otherOp = op->getOperand(index).getDefiningOp<vector::ExtractOp>();
+ if (!otherOp || otherOp.getVector() != vector)
+ break;
+
+ std::optional<int64_t> otherExtractIndex = getExtractIndex(otherOp);
+ if (!otherExtractIndex || *otherExtractIndex != (currentIndex + 1))
+ break;
+
+ currentOps.push_back(otherOp);
+ ++currentIndex;
+ }
+ } else if (isVectorizable(srcOp)) {
+ LLVM_DEBUG(llvm::dbgs() << " Processing vectorizable op "
+ << srcOp->getName() << "\n");
+
+ currentOps.emplace_back(srcOp);
+ for (Operation *op : ArrayRef(node->ops).drop_front()) {
+ Operation *otherOp = op->getOperand(index).getDefiningOp();
+ if (!otherOp || !isEquivalent(otherOp, srcOp))
+ break;
- currentOps.push_back(otherOp);
+ currentOps.push_back(otherOp);
+ }
+ } else {
+ LLVM_DEBUG(llvm::dbgs()
+ << " Unsupported op " << srcOp->getName() << "\n");
+ return;
}
if (currentOps.size() == 1)
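An illustrative sketch of the operand pattern the new vector.extract handling recognizes (assumed IR, not taken from the patch): extracts from the same vector at consecutive static positions form their own node, so the consuming node can later be fed from %v (or a strided slice of it) instead of from scalar extracts:

    %e0 = vector.extract %v[0] : i32 from vector<4xi32>
    %e1 = vector.extract %v[1] : i32 from vector<4xi32>
    %e2 = vector.extract %v[2] : i32 from vector<4xi32>
    %e3 = vector.extract %v[3] : i32 from vector<4xi32>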
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index cbcd553d90f0a..b27a72c6a8fe7 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -283,22 +283,18 @@ func.func @read_read_add_write_size_mismatch(%arg0: memref<8xi32>, %arg1: memref
// CHECK-LABEL: func @read_read_add_write_attrs_mismatch
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
func.func @read_read_add_write_attrs_mismatch(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
- // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
- // CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
- // CHECK: %[[V1:.*]] = vector.extract %[[V0]][2] : i32 from vector<4xi32>
- // CHECK: %[[V2:.*]] = vector.extract %[[V0]][3] : i32 from vector<4xi32>
- // CHECK: %[[V3:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
- // CHECK: %[[V4:.*]] = vector.extract %[[V3]][2] : i32 from vector<4xi32>
- // CHECK: %[[V5:.*]] = vector.extract %[[V3]][3] : i32 from vector<4xi32>
- // CHECK: %[[V6:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
- // CHECK: %[[V7:.*]] = vector.extract_strided_slice %[[V3]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
- // CHECK: %[[V8:.*]] = arith.addi %[[V6]], %[[V7]] overflow<nsw> : vector<2xi32>
- // CHECK: %[[V9:.*]] = vector.from_elements %[[V1]], %[[V2]] : vector<2xi32>
- // CHECK: %[[V10:.*]] = vector.from_elements %[[V4]], %[[V5]] : vector<2xi32>
- // CHECK: %[[V11:.*]] = arith.addi %[[V9]], %[[V10]] overflow<nuw> : vector<2xi32>
- // CHECK: vector.store %[[V8]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
- // CHECK: vector.store %[[V11]], %[[ARG0]][%[[C2]]] : memref<8xi32>, vector<2xi32>
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[V1:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+ // CHECK: %[[V2:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+ // CHECK: %[[V4:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+ // CHECK: %[[V5:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+ // CHECK: %[[V6:.*]] = arith.addi %[[V4]], %[[V5]] overflow<nsw> : vector<2xi32>
+ // CHECK: %[[V7:.*]] = arith.addi %[[V1]], %[[V3]] overflow<nuw> : vector<2xi32>
+ // CHECK: vector.store %[[V6]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+ // CHECK: vector.store %[[V7]], %[[ARG0]][%[[C2]]] : memref<8xi32>, vector<2xi32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -571,34 +567,24 @@ func.func @read_read_add_write_interleaved_use_add(%arg0: memref<8xi32>, %arg1:
// CHECK-LABEL: func @negative_different_blocks
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
func.func @negative_different_blocks(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
- // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
- // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
- // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
- // CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
- // CHECK: %[[V1:.*]] = vector.extract %[[V0]][2] : i32 from vector<4xi32>
- // CHECK: %[[V2:.*]] = vector.extract %[[V0]][3] : i32 from vector<4xi32>
- // CHECK: %[[V3:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
- // CHECK: %[[V4:.*]] = vector.extract %[[V3]][2] : i32 from vector<4xi32>
- // CHECK: %[[V5:.*]] = vector.extract %[[V3]][3] : i32 from vector<4xi32>
- // CHECK: cf.br ^bb1
- // CHECK: ^bb1:
- // CHECK: %[[V6:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
- // CHECK: %[[V7:.*]] = vector.extract_strided_slice %[[V3]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
- // CHECK: %[[V8:.*]] = arith.addi %[[V6]], %[[V7]] : vector<2xi32>
- // CHECK: %[[V9:.*]] = vector.extract %[[V8]][0] : i32 from vector<2xi32>
- // CHECK: %[[V10:.*]] = vector.extract %[[V8]][1] : i32 from vector<2xi32>
- // CHECK: cf.br ^bb2
- // CHECK: ^bb2:
- // TODO: we need to properly handle vector.extract args to vectorizre that
- // CHECK: %[[V11:.*]] = arith.addi %[[V1]], %[[V4]] : i32
- // CHECK: %[[V12:.*]] = arith.addi %[[V2]], %[[V5]] : i32
- // CHECK: cf.br ^bb3
- // CHECK: ^bb3:
- // CHECK: memref.store %[[V9]], %[[ARG0]][%[[C0]]] : memref<8xi32>
- // CHECK: memref.store %[[V10]], %[[ARG0]][%[[C1]]] : memref<8xi32>
- // CHECK: memref.store %[[V11]], %[[ARG0]][%[[C2]]] : memref<8xi32>
- // CHECK: memref.store %[[V12]], %[[ARG0]][%[[C3]]] : memref<8xi32>
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[V1:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+ // CHECK: %[[V2:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+ // CHECK: cf.br ^bb1
+ // CHECK: ^bb1:
+ // CHECK: %[[V4:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+ // CHECK: %[[V5:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+ // CHECK: %[[V6:.*]] = arith.addi %[[V4]], %[[V5]] : vector<2xi32>
+ // CHECK: cf.br ^bb2
+ // CHECK: ^bb2:
+ // CHECK: %[[V7:.*]] = arith.addi %[[V1]], %[[V3]] : vector<2xi32>
+ // CHECK: cf.br ^bb3
+ // CHECK: ^bb3:
+ // CHECK: vector.store %[[V6]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+ // CHECK: vector.store %[[V7]], %[[ARG0]][%[[C2]]] : memref<8xi32>, vector<2xi32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
>From 293b27a2b1c32ad4e36f4923b7d720c92fec2a20 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 19:53:47 +0200
Subject: [PATCH 38/52] handle vec size and dominance
---
.../Vector/Transforms/SLPVectorizer.cpp | 409 +++++++++++-------
mlir/test/Dialect/Vector/slp-vectorize.mlir | 22 +-
2 files changed, 267 insertions(+), 164 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index bdfaa72a36914..e7c550b64e71f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -11,11 +11,13 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
@@ -266,6 +268,13 @@ static bool isVectorizable(Operation *op) {
return true;
}
+/// Get the next operation in the block, assuming `op` is not a terminator.
+static Operation *nextOp(Operation *op) {
+ assert(op && "null op");
+ auto it = op->getIterator();
+ return &*std::next(it);
+}
+
/// A node in the SLP graph representing a group of vectorizable operations
struct SLPGraphNode {
SmallVector<Operation *> ops;
@@ -293,6 +302,31 @@ struct SLPGraphNode {
if (op->isBeforeInBlock(ret))
ret = op;
}
+
+ for (Operation *op : ops) {
+ for (Value opOperand : op->getOperands()) {
+ Operation *defOp = opOperand.getDefiningOp();
+ if (!defOp || defOp->getBlock() != ret->getBlock())
+ continue;
+
+ Operation *next = nextOp(defOp);
+ if (ret->isBeforeInBlock(next))
+ ret = next;
+ }
+ }
+
+ // Try to adjust insertion point to satisfy dominance relations with
+ // operands.
+ for (SLPGraphNode *operand : operands) {
+ Operation *ip = operand->getInsertionPoint();
+ if (!ip)
+ return nullptr;
+
+ Operation *next = nextOp(ip);
+ if (next->getBlock() == ret->getBlock() && ret->isBeforeInBlock(next))
+ ret = next;
+ }
+
return ret;
}
};
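A minimal sketch of the dominance problem the adjusted insertion point solves (assumed IR, not taken from the patch; %arg2..%arg4 are assumed i32 arguments): a member op of a node may consume a value defined after the node's earliest member, so the insertion point has to be pushed past that definition (and past the insertion points of operand nodes):

    %a0 = arith.addi %arg2, %arg3 : i32
    %x  = arith.addi %arg3, %arg4 : i32
    %a1 = arith.addi %arg2, %x : i32

If {%a0, %a1} form a node, its earliest member is %a0, but %a1's operand %x is defined after %a0; getInsertionPoint() therefore moves the insertion point to just after %x.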
@@ -387,159 +421,9 @@ class SLPGraph {
/// Vectorize the operations in the graph.
/// Returns number of nodes vectorized or failure if failed.
- FailureOr<size_t> vectorize(IRRewriter &rewriter) {
- if (nodes.empty())
- return 0;
-
- LLVM_DEBUG(llvm::dbgs()
- << "Vectorizing SLP graph with " << nodes.size() << " nodes\n");
-
- // Get topologically sorted nodes
- SmallVector<SLPGraphNode *> sortedNodes = topologicalSort();
- if (sortedNodes.empty()) {
- LLVM_DEBUG(llvm::dbgs() << "Failed to topologically sort nodes\n");
- return failure();
- }
-
- LLVM_DEBUG({
- llvm::dbgs() << "Topologically sorted nodes:\n";
- for (auto *node : sortedNodes) {
- llvm::dbgs() << " Node with " << node->size()
- << " operations: " << node->op()->getName() << "\n";
- }
- });
-
- auto isBadNode = [&](SLPGraphNode *node) {
- // Do not vectorize stray nodes which are not connected to any other
- // nodes.
- return node->users.empty() && node->operands.empty();
- };
-
- // Update node vec sizes if its inputs vec sizes are smaller.
- // This is nedeed to handle situations when we have 3->3->4 sizes in tree.
- // TODO: It maybe possible to reconstruct the larger vec size combining src
- // smaller vector and scalar arg.
- for (auto *node : sortedNodes) {
- size_t size = node->size();
- for (auto *operand : node->operands)
- size = std::min(size, operand->size());
-
- node->ops.resize(size);
- }
-
- llvm::erase_if(sortedNodes, isBadNode);
-
- IRMapping mapping;
- for (auto *node : sortedNodes) {
- LLVM_DEBUG({
- llvm::dbgs() << "Processing node with " << node->size()
- << " operations\n";
- llvm::dbgs() << " First op: " << *node->op() << "\n";
- });
-
- // `op` is the node with the smallest index in vector and not the
- // nessesarily the good insertion point.
- Operation *op = node->op();
- Operation *ip = node->getInsertionPoint();
- if (!ip)
- return op->emitError("no insertion point found for node");
-
- rewriter.setInsertionPoint(ip);
- int64_t numElements = node->size();
- Location loc = op->getLoc();
-
- auto handleNonVectorInputs = [&](ValueRange operands) {
- for (auto [i, operand] : llvm::enumerate(operands)) {
- if (getNodeForOp(operand.getDefiningOp()))
- continue;
-
- SmallVector<Value> args;
- for (Operation *defOp : node->ops)
- args.push_back(defOp->getOperand(i));
-
- auto vecType = VectorType::get(numElements, operand.getType());
- Value vector =
- rewriter.create<vector::FromElementsOp>(loc, vecType, args);
- mapping.map(operand, vector);
- }
- };
-
- auto handleNonVectorOutputs = [&](Value newResult) {
- for (auto [i, result] : llvm::enumerate(node->ops)) {
- for (OpOperand &use : result->getUses()) {
- Operation *useOwner = use.getOwner();
- if (getNodeForOp(useOwner))
- continue;
-
- Value elem = rewriter.create<vector::ExtractOp>(loc, newResult, i);
- use.set(elem);
- }
- }
- };
-
- auto handleVecSizeMismatch = [&](Value arg) -> Value {
- auto srcType = cast<VectorType>(arg.getType());
- assert(srcType.getRank() == 1);
- if (srcType.getDimSize(0) == numElements)
- return arg;
-
- return rewriter.create<vector::ExtractStridedSliceOp>(loc, arg, 0,
- numElements, 1);
- };
-
- if (auto load = dyn_cast<memref::LoadOp>(op)) {
- auto vecType =
- VectorType::get(numElements, load.getMemRefType().getElementType());
- Value result = rewriter.create<vector::LoadOp>(
- loc, vecType, load.getMemRef(), load.getIndices());
- mapping.map(load.getResult(), result);
- handleNonVectorOutputs(result);
- } else if (auto store = dyn_cast<memref::StoreOp>(op)) {
- handleNonVectorInputs(store.getValueToStore());
- Value val = mapping.lookupOrDefault(store.getValueToStore());
- val = handleVecSizeMismatch(val);
- rewriter.create<vector::StoreOp>(loc, val, store.getMemRef(),
- store.getIndices());
- } else if (isVectorizable(op)) {
- handleNonVectorInputs(op->getOperands());
- Operation *newOp = rewriter.clone(*op, mapping);
- auto resVectorType =
- VectorType::get(numElements, op->getResultTypes().front());
-
- {
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPoint(newOp);
- for (OpOperand &operand : newOp->getOpOperands()) {
- Value newOperand = handleVecSizeMismatch(operand.get());
- operand.set(newOperand);
- }
- }
- newOp->getResult(0).setType(resVectorType);
-
- mapping.map(op->getResults(), newOp->getResults());
- handleNonVectorOutputs(newOp->getResult(0));
- } else if (auto extract = dyn_cast<vector::ExtractOp>(op)) {
- Value val = handleVecSizeMismatch(extract.getVector());
- mapping.map(extract.getResult(), val);
- } else {
- op->emitError("unsupported operation");
- return failure();
- }
- }
-
- LLVM_DEBUG(llvm::dbgs() << "Erasing original ops\n");
-
- // As all nodes were cloned, we need to erase the original ops in reverse
- // topo order to avoid invalidation users.
- for (auto *node : llvm::reverse(sortedNodes)) {
- for (Operation *op : node->ops) {
- rewriter.eraseOp(op);
- }
- }
-
- LLVM_DEBUG(llvm::dbgs() << "Vectorization completed successfully\n");
- return sortedNodes.size();
- }
+ FailureOr<size_t>
+ vectorize(IRRewriter &rewriter,
+ llvm::function_ref<bool(Type, size_t)> isValidVecType);
/// Print the graph structure
[[maybe_unused]] void print() const {
@@ -736,6 +620,31 @@ struct OperationsFingerprint {
DenseMap<Operation *, Fingerprint> fingerprints;
};
+/// Check if op input/output types can be vectorized.
+static bool
+checkOpVecType(SLPGraphNode *node,
+ llvm::function_ref<bool(Type, size_t)> isValidVecType) {
+ Operation *op = node->op();
+ size_t size = node->size();
+ if (Type elementType = getElementType(op))
+ return isValidVecType(elementType, size);
+
+ if (isVectorizable(op)) {
+ for (auto type :
+ llvm::concat<Type>(op->getResultTypes(), op->getOperandTypes())) {
+ if (!isValidVecType(type, size))
+ return false;
+ }
+ return true;
+ }
+
+ if (auto extract = dyn_cast<vector::ExtractOp>(op))
+ return isValidVecType(extract.getResult().getType(), size);
+
+ LLVM_DEBUG(llvm::dbgs() << "Unsupported op " << op->getName() << "\n");
+ return false;
+}
+
/// Check if two ops are equivalent for the purposes of SLP vectorization, i.e.
/// they can be merged into single vector op.
static bool isEquivalent(Operation *op1, Operation *op2) {
@@ -924,9 +833,176 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
return graph;
}
+FailureOr<size_t>
+SLPGraph::vectorize(IRRewriter &rewriter,
+ llvm::function_ref<bool(Type, size_t)> isValidVecType) {
+ if (nodes.empty())
+ return 0;
+
+ LLVM_DEBUG(llvm::dbgs() << "Vectorizing SLP graph with " << nodes.size()
+ << " nodes\n");
+
+ // Get topologically sorted nodes
+ SmallVector<SLPGraphNode *> sortedNodes = topologicalSort();
+ if (sortedNodes.empty()) {
+ LLVM_DEBUG(llvm::dbgs() << "Failed to topologically sort nodes\n");
+ return failure();
+ }
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "Topologically sorted nodes:\n";
+ for (auto *node : sortedNodes) {
+ llvm::dbgs() << " Node with " << node->size()
+ << " operations: " << node->op()->getName() << "\n";
+ }
+ });
+
+ auto isBadNode = [&](SLPGraphNode *node) {
+ // Do not vectorize stray nodes which are not connected to any other
+ // nodes.
+ return (node->users.empty() && node->operands.empty()) || node->size() <= 1;
+ };
+
+ // Update node vec sizes if its inputs' vec sizes are smaller.
+ // This is needed to handle situations when we have 3->3->4 sizes in the tree.
+ // TODO: It may be possible to reconstruct the larger vec size by combining
+ // the smaller src vector and a scalar arg.
+ for (auto *node : sortedNodes) {
+ size_t size = node->size();
+ for (auto *operand : node->operands)
+ size = std::min(size, operand->size());
+
+ node->ops.resize(size);
+
+ while (node->size() > 1) {
+ if (checkOpVecType(node, isValidVecType))
+ break;
+
+ node->ops.pop_back();
+ }
+ }
+
+ llvm::erase_if(sortedNodes, isBadNode);
+
+ IRMapping mapping;
+ for (auto *node : sortedNodes) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "Processing node with " << node->size()
+ << " operations\n";
+ llvm::dbgs() << " First op: " << *node->op() << "\n";
+ });
+
+ // `op` is the op with the smallest index in the `ops` vector and not
+ // necessarily a good insertion point.
+ Operation *op = node->op();
+ Operation *ip = node->getInsertionPoint();
+ if (!ip)
+ return op->emitError("no insertion point found for node");
+
+ LLVM_DEBUG(llvm::dbgs() << " Insertion point: " << *ip << "\n");
+
+ rewriter.setInsertionPoint(ip);
+ int64_t numElements = node->size();
+ Location loc = op->getLoc();
+
+ auto handleNonVectorInputs = [&](ValueRange operands) {
+ for (auto [i, operand] : llvm::enumerate(operands)) {
+ if (getNodeForOp(operand.getDefiningOp()))
+ continue;
+
+ SmallVector<Value> args;
+ for (Operation *defOp : node->ops)
+ args.push_back(defOp->getOperand(i));
+
+ auto vecType = VectorType::get(numElements, operand.getType());
+ Value vector =
+ rewriter.create<vector::FromElementsOp>(loc, vecType, args);
+ mapping.map(operand, vector);
+ }
+ };
+
+ auto handleNonVectorOutputs = [&](Value newResult) {
+ for (auto [i, result] : llvm::enumerate(node->ops)) {
+ for (OpOperand &use : result->getUses()) {
+ Operation *useOwner = use.getOwner();
+ if (getNodeForOp(useOwner))
+ continue;
+
+ Value elem = rewriter.create<vector::ExtractOp>(loc, newResult, i);
+ use.set(elem);
+ }
+ }
+ };
+
+ auto handleVecSizeMismatch = [&](Value arg) -> Value {
+ auto srcType = cast<VectorType>(arg.getType());
+ assert(srcType.getRank() == 1);
+ if (srcType.getDimSize(0) == numElements)
+ return arg;
+
+ return rewriter.create<vector::ExtractStridedSliceOp>(loc, arg, 0,
+ numElements, 1);
+ };
+
+ if (auto load = dyn_cast<memref::LoadOp>(op)) {
+ auto vecType =
+ VectorType::get(numElements, load.getMemRefType().getElementType());
+ Value result = rewriter.create<vector::LoadOp>(
+ loc, vecType, load.getMemRef(), load.getIndices());
+ mapping.map(load.getResult(), result);
+ handleNonVectorOutputs(result);
+ } else if (auto store = dyn_cast<memref::StoreOp>(op)) {
+ handleNonVectorInputs(store.getValueToStore());
+ Value val = mapping.lookupOrDefault(store.getValueToStore());
+ val = handleVecSizeMismatch(val);
+ rewriter.create<vector::StoreOp>(loc, val, store.getMemRef(),
+ store.getIndices());
+ } else if (isVectorizable(op)) {
+ handleNonVectorInputs(op->getOperands());
+ Operation *newOp = rewriter.clone(*op, mapping);
+ auto resVectorType =
+ VectorType::get(numElements, op->getResultTypes().front());
+
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(newOp);
+ for (OpOperand &operand : newOp->getOpOperands()) {
+ Value newOperand = handleVecSizeMismatch(operand.get());
+ operand.set(newOperand);
+ }
+ }
+ newOp->getResult(0).setType(resVectorType);
+
+ mapping.map(op->getResults(), newOp->getResults());
+ handleNonVectorOutputs(newOp->getResult(0));
+ } else if (auto extract = dyn_cast<vector::ExtractOp>(op)) {
+ Value val = handleVecSizeMismatch(extract.getVector());
+ mapping.map(extract.getResult(), val);
+ } else {
+ op->emitError("unsupported operation");
+ return failure();
+ }
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << "Erasing original ops\n");
+
+ // As all nodes were cloned, we need to erase the original ops in reverse
+ // topo order to avoid invalidating their users.
+ for (auto *node : llvm::reverse(sortedNodes)) {
+ for (Operation *op : node->ops) {
+ rewriter.eraseOp(op);
+ }
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << "Vectorization completed successfully\n");
+ return sortedNodes.size();
+}
+
/// Try to vectorize ops in a block.
/// Returns number of nodes vectorized or error flag if failed.
-static FailureOr<size_t> tryToVectorizeInBlock(Block &block) {
+static FailureOr<size_t>
+tryToVectorizeInBlock(Block &block,
+ llvm::function_ref<bool(Type, size_t)> isValidVecType) {
LLVM_DEBUG(llvm::dbgs() << "Processing block in operation: "
<< block.getParentOp()->getName() << "\n");
@@ -956,22 +1032,42 @@ static FailureOr<size_t> tryToVectorizeInBlock(Block &block) {
// Vectorize the graph
IRRewriter rewriter(block.getParentOp()->getContext());
- FailureOr<size_t> numNodesVectorized = graph.vectorize(rewriter);
+ FailureOr<size_t> numNodesVectorized =
+ graph.vectorize(rewriter, isValidVecType);
if (failed(numNodesVectorized))
LLVM_DEBUG(llvm::dbgs() << "Failed to vectorize graph\n");
return numNodesVectorized;
}
+static bool isPow2(size_t size) {
+ assert(size > 0);
+ return (size & (size - 1)) == 0;
+}
+
void GreedySLPVectorizerPass::runOnOperation() {
Operation *op = getOperation();
+ const DataLayout *dataLayout = nullptr;
+ auto isValidVecType = [&](Type type, size_t count) {
+ if (!isPow2(count))
+ return false;
+
+ if (!dataLayout)
+ dataLayout = &getAnalysis<DataLayoutAnalysis>().getAtOrAbove(op);
+
+ auto sizeInBits = dataLayout->getTypeSizeInBits(type);
+
+ return sizeInBits * count <= 256;
+ };
+
// Run until fixed point is reached.
bool changed;
do {
changed = false;
auto visitor = [&](Block *block) -> WalkResult {
- FailureOr<size_t> numNodesVectorized = tryToVectorizeInBlock(*block);
+ FailureOr<size_t> numNodesVectorized =
+ tryToVectorizeInBlock(*block, isValidVecType);
if (failed(numNodesVectorized))
return WalkResult::interrupt();
@@ -987,6 +1083,7 @@ void GreedySLPVectorizerPass::runOnOperation() {
auto config = GreedyRewriteConfig().enableFolding().enableConstantCSE();
(void)applyPatternsGreedily(op, {}, config);
}
+ op->dump();
} while (changed);
}
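The isPow2 gate above means a group whose size is not a power of two is trimmed rather than rejected outright: the while loop in vectorize() pops operations off the node until the remaining count passes isValidVecType, and the popped ops stay scalar. A minimal illustrative sketch (hypothetical IR, assuming i32 elements) of the shape a 3-element load group would be expected to take after the pass:

func.func @pow2_trim_sketch(%m: memref<8xi32>) {
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  // Expected result for a group of 3 adjacent loads: a 2-wide vector load
  // plus one leftover scalar load for the element that was popped.
  %v = vector.load %m[%c0] : memref<8xi32>, vector<2xi32>
  %s = memref.load %m[%c2] : memref<8xi32>
  return
}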
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index b27a72c6a8fe7..c363fe9491ee3 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -479,16 +479,22 @@ func.func private @use(i32)
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
func.func @read_read_add_write_interleaved_use(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[C3:.*]] = arith.constant 3 : index
- // CHECK: %[[V0:.*]] = memref.load %[[ARG0]][%[[C3]]] : memref<8xi32>
- // CHECK: %[[V1:.*]] = memref.load %[[ARG1]][%[[C3]]] : memref<8xi32>
+ // CHECK: %[[V0:.*]] = memref.load %[[ARG0]][%[[C3]]] : memref<8xi32>
+ // CHECK: %[[V1:.*]] = memref.load %[[ARG1]][%[[C3]]] : memref<8xi32>
// CHECK: call @use(%[[V0]]) : (i32) -> ()
- // CHECK: %[[V2:.*]] = arith.addi %[[V0]], %[[V1]] : i32
- // CHECK: %[[V3:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<3xi32>
- // CHECK: %[[V4:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<3xi32>
- // CHECK: %[[V5:.*]] = arith.addi %[[V3]], %[[V4]] : vector<3xi32>
- // CHECK: vector.store %[[V5]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<3xi32>
- // CHECK: memref.store %[[V2]], %[[ARG0]][%[[C3]]] : memref<8xi32>
+ // CHECK: %[[V2:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+ // CHECK: %[[V3:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+ // CHECK: %[[V4:.*]] = memref.load %[[ARG0]][%[[C2]]] : memref<8xi32>
+ // CHECK: %[[V5:.*]] = memref.load %[[ARG1]][%[[C2]]] : memref<8xi32>
+ // CHECK: %[[V6:.*]] = vector.from_elements %[[V4]], %[[V0]] : vector<2xi32>
+ // CHECK: %[[V7:.*]] = vector.from_elements %[[V5]], %[[V1]] : vector<2xi32>
+ // CHECK: %[[V8:.*]] = arith.addi %[[V6]], %[[V7]] : vector<2xi32>
+ // CHECK: %[[V9:.*]] = arith.addi %[[V2]], %[[V3]] : vector<2xi32>
+ // CHECK: vector.store %[[V9]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+ // CHECK: vector.store %[[V8]], %[[ARG0]][%[[C2]]] : memref<8xi32>, vector<2xi32>
+
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
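The insertion-point adjustment in this patch matters when the scalars being grouped have operands defined at scattered points in the block. A minimal illustrative sketch (hypothetical IR, not one of the tests above): the two scalar adds form one node, but one of their operands is defined after the first add, so the vectorized add has to be emitted after the last contributing definition, which is what the getInsertionPoint() changes above compute.

func.func @insertion_point_sketch(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %a0 = memref.load %arg0[%c0] : memref<8xi32>
  %b0 = memref.load %arg1[%c0] : memref<8xi32>
  %r0 = arith.addi %a0, %b0 : i32
  // %b1 and %a1 are defined after the first add, so a vector add covering
  // both scalar adds cannot simply be placed where %r0 was; it must dominate
  // nothing it uses, i.e. go after %a1.
  %b1 = memref.load %arg1[%c1] : memref<8xi32>
  %a1 = memref.load %arg0[%c1] : memref<8xi32>
  %r1 = arith.addi %a1, %b1 : i32
  memref.store %r0, %arg0[%c0] : memref<8xi32>
  memref.store %r1, %arg0[%c1] : memref<8xi32>
  return
}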
>From 8db555609147d308d14ac317fb9b20ba4ad11828 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 19:59:51 +0200
Subject: [PATCH 39/52] cache insertion point
---
mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp | 9 +++++++--
1 file changed, 7 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index e7c550b64e71f..ced04e406edbc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -280,6 +280,7 @@ struct SLPGraphNode {
SmallVector<Operation *> ops;
SmallVector<SLPGraphNode *> users;
SmallVector<SLPGraphNode *> operands;
+ Operation *insertionPoint = nullptr;
bool isRoot = false;
SLPGraphNode() = default;
@@ -293,10 +294,13 @@ struct SLPGraphNode {
return ops.front();
}
- Operation *getInsertionPoint() const {
+ Operation *getInsertionPoint() {
+ assert(!ops.empty() && "empty node");
+ if (insertionPoint)
+ return insertionPoint;
+
// Find the topologically first node, which is not necessarily the first in the
// `ops` as `ops` are sorted by their position in the vector.
- assert(!ops.empty() && "empty node");
Operation *ret = op();
for (Operation *op : ArrayRef(ops).drop_front()) {
if (op->isBeforeInBlock(ret))
@@ -327,6 +331,7 @@ struct SLPGraphNode {
ret = next;
}
+ insertionPoint = ret;
return ret;
}
};
>From e4c1589f545dd8f2a0b85aa536b95c2464d16dfb Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 20:33:56 +0200
Subject: [PATCH 40/52] test
---
.../Vector/Transforms/SLPVectorizer.cpp | 2 +-
mlir/test/Dialect/Vector/slp-vectorize.mlir | 63 +++++++++++++++++++
2 files changed, 64 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index ced04e406edbc..a73da4a93dec3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -995,6 +995,7 @@ SLPGraph::vectorize(IRRewriter &rewriter,
// topo order to avoid invalidation users.
for (auto *node : llvm::reverse(sortedNodes)) {
for (Operation *op : node->ops) {
+ LLVM_DEBUG(llvm::dbgs() << "Erasing op: " << *op << "\n");
rewriter.eraseOp(op);
}
}
@@ -1088,7 +1089,6 @@ void GreedySLPVectorizerPass::runOnOperation() {
auto config = GreedyRewriteConfig().enableFolding().enableConstantCSE();
(void)applyPatternsGreedily(op, {}, config);
}
- op->dump();
} while (changed);
}
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index c363fe9491ee3..0a40037b015b1 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -245,6 +245,69 @@ func.func @read_read_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
}
+// CHECK-LABEL: func @read_read_add_write_seven
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xindex>, %[[ARG1:.*]]: memref<8xindex>)
+func.func @read_read_add_write_seven(%arg0: memref<8xindex>, %arg1: memref<8xindex>) {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+ // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+ // CHECK: %[[A0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xindex>, vector<4xindex>
+ // CHECK: %[[A1:.*]] = vector.load %[[ARG0]][%[[C4]]] : memref<8xindex>, vector<2xindex>
+ // CHECK: %[[A2:.*]] = memref.load %[[ARG0]][%[[C6]]] : memref<8xindex>
+ // CHECK: %[[B0:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xindex>, vector<4xindex>
+ // CHECK: %[[B1:.*]] = vector.load %[[ARG1]][%[[C4]]] : memref<8xindex>, vector<2xindex>
+ // CHECK: %[[B2:.*]] = memref.load %[[ARG1]][%[[C6]]] : memref<8xindex>
+ // CHECK: %[[RES0:.*]] = arith.addi %[[A0]], %[[B0]] : vector<4xindex>
+ // CHECK: %[[RES1:.*]] = arith.addi %[[A1]], %[[B1]] : vector<2xindex>
+ // CHECK: %[[RES2:.*]] = arith.addi %[[A2]], %[[B2]] : index
+ // CHECK: vector.store %[[RES0]], %[[ARG0]][%[[C0]]] : memref<8xindex>, vector<4xindex>
+ // CHECK: vector.store %[[RES1]], %[[ARG0]][%[[C4]]] : memref<8xindex>, vector<2xindex>
+ // CHECK: memref.store %[[RES2]], %[[ARG0]][%[[C6]]] : memref<8xindex>
+
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %c4 = arith.constant 4 : index
+ %c5 = arith.constant 5 : index
+ %c6 = arith.constant 6 : index
+
+ %0 = memref.load %arg0[%c0] : memref<8xindex>
+ %1 = memref.load %arg0[%c1] : memref<8xindex>
+ %2 = memref.load %arg0[%c2] : memref<8xindex>
+ %3 = memref.load %arg0[%c3] : memref<8xindex>
+ %4 = memref.load %arg0[%c4] : memref<8xindex>
+ %5 = memref.load %arg0[%c5] : memref<8xindex>
+ %6 = memref.load %arg0[%c6] : memref<8xindex>
+
+ %7 = memref.load %arg1[%c0] : memref<8xindex>
+ %8 = memref.load %arg1[%c1] : memref<8xindex>
+ %9 = memref.load %arg1[%c2] : memref<8xindex>
+ %10 = memref.load %arg1[%c3] : memref<8xindex>
+ %11 = memref.load %arg1[%c4] : memref<8xindex>
+ %12 = memref.load %arg1[%c5] : memref<8xindex>
+ %13 = memref.load %arg1[%c6] : memref<8xindex>
+
+ %14 = arith.addi %0, %7 : index
+ %15 = arith.addi %1, %8 : index
+ %16 = arith.addi %2, %9 : index
+ %17 = arith.addi %3, %10 : index
+ %18 = arith.addi %4, %11 : index
+ %19 = arith.addi %5, %12 : index
+ %20 = arith.addi %6, %13 : index
+
+ memref.store %14, %arg0[%c0] : memref<8xindex>
+ memref.store %15, %arg0[%c1] : memref<8xindex>
+ memref.store %16, %arg0[%c2] : memref<8xindex>
+ memref.store %17, %arg0[%c3] : memref<8xindex>
+ memref.store %18, %arg0[%c4] : memref<8xindex>
+ memref.store %19, %arg0[%c5] : memref<8xindex>
+ memref.store %20, %arg0[%c6] : memref<8xindex>
+
+ return
+}
+
+
// CHECK-LABEL: func @read_read_add_write_size_mismatch
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
func.func @read_read_add_write_size_mismatch(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
>From fb190bf6e136c85dee1b748640fc405bb5b34f6b Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 20:47:44 +0200
Subject: [PATCH 41/52] pass option
---
mlir/include/mlir/Dialect/Vector/Transforms/Passes.td | 10 ++++++----
mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp | 4 +++-
mlir/test/Dialect/Vector/slp-vectorize.mlir | 2 +-
3 files changed, 10 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index d5c31c9f78409..970e488d3494d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -44,12 +44,14 @@ def GreedySLPVectorizer : Pass<"greedy-slp-vectorizer"> {
This is a greedy vectorizer; it doesn't have any cost model (yet) and it
tries to create vector ops if we have at least 2 potential ops.
-
- It doesn't check if target actually supports resulted vectors either, user
- will need a follow up pass which will split large and/or unaliggned vectors
- into sizes actually supported by the target.
}];
let dependentDialects = ["mlir::vector::VectorDialect"];
+
+ let options = [
+ Option<"maxVectorBitwidth", "max-vector-bitwidth", "unsigned",
+ /*default=*/"std::numeric_limits<unsigned>::max()",
+ "Maximum supported vector bitwidth">,
+ ];
}
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index a73da4a93dec3..0e650359d3339 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -515,6 +515,8 @@ class SLPGraph {
struct GreedySLPVectorizerPass
: public mlir::vector::impl::GreedySLPVectorizerBase<
GreedySLPVectorizerPass> {
+ using GreedySLPVectorizerBase::GreedySLPVectorizerBase;
+
void runOnOperation() override;
};
@@ -1064,7 +1066,7 @@ void GreedySLPVectorizerPass::runOnOperation() {
auto sizeInBits = dataLayout->getTypeSizeInBits(type);
- return sizeInBits * count <= 256;
+ return sizeInBits * count <= this->maxVectorBitwidth;
};
// Run until fixed point is reached.
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 0a40037b015b1..262db81e16f21 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --greedy-slp-vectorizer | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(greedy-slp-vectorizer{max-vector-bitwidth=256}))' | FileCheck %s
// CHECK-LABEL: func @negative_single_op
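With the option exposed, the bitwidth cap can be tightened per target; the per-element size still comes from DataLayoutAnalysis. As a hypothetical illustration (the 128-bit value and function name are assumptions, not taken from the tests): running the pipeline with greedy-slp-vectorizer{max-vector-bitwidth=128} over eight adjacent i32 loads would be expected to produce two 4-wide loads rather than a single 8-wide one.

func.func @cap_128_sketch(%m: memref<8xi32>) {
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  // Expected output shape under a 128-bit cap: 4 x i32 = 128 bits per vector,
  // so the 8-element group is split into two loads.
  %v0 = vector.load %m[%c0] : memref<8xi32>, vector<4xi32>
  %v1 = vector.load %m[%c4] : memref<8xi32>, vector<4xi32>
  return
}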
>From f542644fe567f051d18882d08c24dac718943584 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 20:54:35 +0200
Subject: [PATCH 42/52] fix test name
---
mlir/test/Dialect/Vector/slp-vectorize.mlir | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 262db81e16f21..75b77561ed891 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -633,9 +633,9 @@ func.func @read_read_add_write_interleaved_use_add(%arg0: memref<8xi32>, %arg1:
}
-// CHECK-LABEL: func @negative_different_blocks
+// CHECK-LABEL: func @different_blocks
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
-func.func @negative_different_blocks(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+func.func @different_blocks(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
>From c89d7c6ee7b7168a628e2153f318421001b7a224 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 26 May 2025 21:04:44 +0200
Subject: [PATCH 43/52] cleanup
---
mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 0e650359d3339..07bba1093d741 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -1031,7 +1031,7 @@ tryToVectorizeInBlock(Block &block,
<< " operations\n";
}
});
- rootGroups.append(contiguousGroups.begin(), contiguousGroups.end());
+ rootGroups.append(contiguousGroups);
}
// Build the SLP graph from root groups
>From d8fabe64323f95288b2debec47b411ce055cfeac Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 27 May 2025 00:14:34 +0200
Subject: [PATCH 44/52] AffineApplyOp index support
---
.../Vector/Transforms/SLPVectorizer.cpp | 26 +++++++++++++++-
mlir/test/Dialect/Vector/slp-vectorize.mlir | 31 +++++++++++++++++++
2 files changed, 56 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 07bba1093d741..892a8807d70e4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/DataLayoutAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -142,6 +143,27 @@ static Type getElementType(Operation *op) {
return {};
}
+static bool isAdjacentAffineMapIndices(Value idx1, Value idx2) {
+ auto applyOp1 = idx1.getDefiningOp<affine::AffineApplyOp>();
+ if (!applyOp1)
+ return false;
+
+ auto applyOp2 = idx2.getDefiningOp<affine::AffineApplyOp>();
+ if (!applyOp2)
+ return false;
+
+ if (applyOp1.getOperands() != applyOp2.getOperands())
+ return false;
+
+ AffineExpr expr1 = applyOp1.getAffineMap().getResult(0);
+ AffineExpr expr2 = applyOp2.getAffineMap().getResult(0);
+ auto diff =
+ simplifyAffineExpr(expr2 - expr1, 0, applyOp1.getOperands().size());
+
+ auto diffConst = dyn_cast<AffineConstantExpr>(diff);
+ return diffConst && diffConst.getValue() == 1;
+}
+
/// Check if two indices are consecutive, i.e index1 + 1 == index2.
static bool isAdjacentIndices(Value idx1, Value idx2) {
if (auto c1 = getConstantIntValue(idx1)) {
@@ -160,7 +182,9 @@ static bool isAdjacentIndices(Value idx1, Value idx2) {
}
}
- // TODO: Handle affine.apply, etc
+ if (isAdjacentAffineMapIndices(idx1, idx2))
+ return true;
+
return false;
}
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 75b77561ed891..4328926f8071f 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -143,6 +143,37 @@ func.func @read_write_add_index_interleaved(%arg0: memref<8xi32>, %arg1: memref<
}
+#map0 = affine_map<()[s0, s1] -> (s1 * s0)>
+#map1 = affine_map<()[s0, s1] -> (s1 * s0 + 1)>
+#map2 = affine_map<()[s0, s1] -> (s1 * s0 + 2)>
+#map3 = affine_map<()[s0, s1] -> (s1 * s0 + 3)>
+
+// CHECK-LABEL: func @read_write_affine_apply
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+func.func @read_write_affine_apply(%arg0: memref<8xi32>, %arg1: memref<8xi32>, %arg2: index, %arg3: index) {
+ // CHECK: %[[IDX:.*]] = affine.apply #{{.*}}()[%[[ARG2]], %[[ARG3]]]
+ // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[IDX]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: vector.store %[[RES]], %[[ARG0]][%[[IDX]]] : memref<8xi32>, vector<4xi32>
+
+ %ind0 = affine.apply #map0()[%arg2, %arg3]
+ %ind1 = affine.apply #map1()[%arg2, %arg3]
+ %ind2 = affine.apply #map2()[%arg2, %arg3]
+ %ind3 = affine.apply #map3()[%arg2, %arg3]
+
+ %0 = memref.load %arg0[%ind0] : memref<8xi32>
+ %1 = memref.load %arg0[%ind1] : memref<8xi32>
+ %2 = memref.load %arg0[%ind2] : memref<8xi32>
+ %3 = memref.load %arg0[%ind3] : memref<8xi32>
+
+ memref.store %0, %arg0[%ind0] : memref<8xi32>
+ memref.store %1, %arg0[%ind1] : memref<8xi32>
+ memref.store %2, %arg0[%ind2] : memref<8xi32>
+ memref.store %3, %arg0[%ind3] : memref<8xi32>
+
+ return
+}
+
+
// CHECK-LABEL: func @read_read_add
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i32, i32, i32) {
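The affine.apply handling above only treats two indices as adjacent when both are produced by affine.apply over the same operand list and the difference of their map results simplifies to the constant 1. A hedged counterexample (hypothetical IR, map names illustrative): indices that differ by 2 fail the check, so the loads would be expected to stay scalar.

#stride2_a = affine_map<()[s0] -> (s0 * 2)>
#stride2_b = affine_map<()[s0] -> (s0 * 2 + 2)>

func.func @non_adjacent_affine_sketch(%m: memref<8xi32>, %s: index) -> (i32, i32) {
  %i0 = affine.apply #stride2_a()[%s]
  %i1 = affine.apply #stride2_b()[%s]
  // The two indices differ by 2, not 1, so this pair is not merged.
  %0 = memref.load %m[%i0] : memref<8xi32>
  %1 = memref.load %m[%i1] : memref<8xi32>
  return %0, %1 : i32, i32
}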
>From c6021c52d52d23affe377ab80350a3aef019fc76 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 27 May 2025 13:12:33 +0200
Subject: [PATCH 45/52] fix offset
---
mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp | 8 +++++---
mlir/test/Dialect/Vector/slp-vectorize.mlir | 8 ++++----
2 files changed, 9 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 892a8807d70e4..81aa63b31bdc3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -965,13 +965,13 @@ SLPGraph::vectorize(IRRewriter &rewriter,
}
};
- auto handleVecSizeMismatch = [&](Value arg) -> Value {
+ auto handleVecSizeMismatch = [&](Value arg, int64_t offset = 0) -> Value {
auto srcType = cast<VectorType>(arg.getType());
assert(srcType.getRank() == 1);
if (srcType.getDimSize(0) == numElements)
return arg;
- return rewriter.create<vector::ExtractStridedSliceOp>(loc, arg, 0,
+ return rewriter.create<vector::ExtractStridedSliceOp>(loc, arg, offset,
numElements, 1);
};
@@ -1007,7 +1007,9 @@ SLPGraph::vectorize(IRRewriter &rewriter,
mapping.map(op->getResults(), newOp->getResults());
handleNonVectorOutputs(newOp->getResult(0));
} else if (auto extract = dyn_cast<vector::ExtractOp>(op)) {
- Value val = handleVecSizeMismatch(extract.getVector());
+ // We already verified the index is valid during graph construction.
+ int64_t offset = *getExtractIndex(extract);
+ Value val = handleVecSizeMismatch(extract.getVector(), offset);
mapping.map(extract.getResult(), val);
} else {
op->emitError("unsupported operation");
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 4328926f8071f..e339eb5755bd6 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -380,9 +380,9 @@ func.func @read_read_add_write_attrs_mismatch(%arg0: memref<8xi32>, %arg1: memre
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
- // CHECK: %[[V1:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+ // CHECK: %[[V1:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
// CHECK: %[[V2:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
- // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+ // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
// CHECK: %[[V4:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
// CHECK: %[[V5:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
// CHECK: %[[V6:.*]] = arith.addi %[[V4]], %[[V5]] overflow<nsw> : vector<2xi32>
@@ -670,9 +670,9 @@ func.func @different_blocks(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
- // CHECK: %[[V1:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+ // CHECK: %[[V1:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
// CHECK: %[[V2:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
- // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+ // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
// CHECK: cf.br ^bb1
// CHECK: ^bb1:
// CHECK: %[[V4:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
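The offset fix above matters when a narrower node reads the upper lanes of a wider producer: the slice now starts at the extract index instead of always 0. A minimal sketch (hypothetical function name) of the slice the updated CHECK lines expect, taking lanes [2, 3] out of a 4-wide vector:

func.func @upper_half_sketch(%v: vector<4xi32>) -> vector<2xi32> {
  // offsets = [2] selects the upper half; with the previous hard-coded offset
  // of 0 the lower half would have been reused for both slices.
  %hi = vector.extract_strided_slice %v
      {offsets = [2], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
  return %hi : vector<2xi32>
}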
>From 74e1d80def3a0191e6e4b340b2d67ede682f4618 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 31 May 2025 21:14:51 +0200
Subject: [PATCH 46/52] support for 1-element vectors
---
.../Vector/Transforms/SLPVectorizer.cpp | 126 +++++++++++++-----
mlir/test/Dialect/Vector/slp-vectorize.mlir | 74 ++++++++++
2 files changed, 168 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 81aa63b31bdc3..a9411c7c903bd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -68,6 +68,32 @@ static bool maybeWriteOp(Operation *op) {
return effectInterface.hasEffect<MemoryEffects::Write>();
}
+static Type getVectorElementType(VectorType vectorType) {
+ if (vectorType.getRank() > 1 || vectorType.isScalable() ||
+ vectorType.getNumElements() != 1)
+ return {};
+
+ return vectorType.getElementType();
+}
+
+static Type getElementType(Operation *op) {
+ assert(op && "null op");
+ if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+ return loadOp.getResult().getType();
+ if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+ return storeOp.getValueToStore().getType();
+ if (auto loadOp = dyn_cast<vector::LoadOp>(op))
+ return getVectorElementType(loadOp.getVectorType());
+ if (auto storeOp = dyn_cast<vector::StoreOp>(op))
+ return getVectorElementType(storeOp.getVectorType());
+ return {};
+}
+
+static bool isSupportedMemOp(Operation *op) {
+ assert(op && "null op");
+ return isa_and_present<IntegerType, FloatType, IndexType>(getElementType(op));
+}
+
/// Collect all memory operations in the block into groups.
/// Each group contains either all loads or all stores, uninterrupted by
/// operations of the other type.
@@ -85,7 +111,7 @@ static SmallVector<MemoryOpGroup> collectMemoryOpGroups(Block &block) {
}
}
- if (!isa<memref::LoadOp, memref::StoreOp>(op))
+ if (!isSupportedMemOp(&op))
continue;
bool isLoad = maybeReadOp(&op);
@@ -109,6 +135,19 @@ static Value getBase(Operation *op) {
return loadOp.getMemRef();
if (auto storeOp = dyn_cast<memref::StoreOp>(op))
return storeOp.getMemRef();
+ if (auto loadOp = dyn_cast<vector::LoadOp>(op))
+ return loadOp.getBase();
+ if (auto storeOp = dyn_cast<vector::StoreOp>(op))
+ return storeOp.getBase();
+ return {};
+}
+
+static Value getValueToStore(Operation *op) {
+ assert(op && "null op");
+ if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+ return storeOp.getValueToStore();
+ if (auto storeOp = dyn_cast<vector::StoreOp>(op))
+ return storeOp.getValueToStore();
return {};
}
@@ -131,15 +170,10 @@ static ValueRange getIndices(Operation *op) {
return loadOp.getIndices();
if (auto storeOp = dyn_cast<memref::StoreOp>(op))
return storeOp.getIndices();
- return {};
-}
-
-static Type getElementType(Operation *op) {
- assert(op && "null op");
- if (auto loadOp = dyn_cast<memref::LoadOp>(op))
- return loadOp.getResult().getType();
- if (auto storeOp = dyn_cast<memref::StoreOp>(op))
- return storeOp.getValueToStore().getType();
+ if (auto loadOp = dyn_cast<vector::LoadOp>(op))
+ return loadOp.getIndices();
+ if (auto storeOp = dyn_cast<vector::StoreOp>(op))
+ return storeOp.getIndices();
return {};
}
@@ -285,7 +319,15 @@ static bool isVectorizable(Operation *op) {
for (auto type :
llvm::concat<Type>(op->getResultTypes(), op->getOperandTypes())) {
- if (!type.isIntOrIndexOrFloat())
+ if (auto vectorType = dyn_cast<VectorType>(type)) {
+ if (vectorType.getRank() > 1 || vectorType.isScalable() ||
+ vectorType.getNumElements() != 1)
+ return false;
+
+ type = vectorType.getElementType();
+ }
+
+ if (!isa<IntegerType, FloatType, IndexType>(type))
return false;
}
@@ -464,8 +506,7 @@ class SLPGraph {
for (const auto &node : nodes) {
if (!node->isRoot)
continue;
- llvm::dbgs() << " "
- << (isa<memref::LoadOp>(node->op()) ? "LOAD" : "STORE")
+ llvm::dbgs() << " " << (maybeReadOp(node->op()) ? "LOAD" : "STORE")
<< " group with " << node->size() << " operations:\n";
for (auto *op : node->ops) {
llvm::dbgs() << " " << *op << "\n";
@@ -657,20 +698,36 @@ checkOpVecType(SLPGraphNode *node,
llvm::function_ref<bool(Type, size_t)> isValidVecType) {
Operation *op = node->op();
size_t size = node->size();
- if (Type elementType = getElementType(op))
- return isValidVecType(elementType, size);
+ auto checkRes = [](bool res) -> bool {
+ LLVM_DEBUG(llvm::dbgs() << (res ? "true" : "false") << "\n");
+ return res;
+ };
+
+ if (Type elementType = getElementType(op)) {
+ LLVM_DEBUG(llvm::dbgs() << "Checking if type " << elementType
+ << " with size " << size << " can be vectorized: ");
+ return checkRes(isValidVecType(elementType, size));
+ }
if (isVectorizable(op)) {
for (auto type :
llvm::concat<Type>(op->getResultTypes(), op->getOperandTypes())) {
- if (!isValidVecType(type, size))
+ Type elementType = getElementTypeOrSelf(type);
+ LLVM_DEBUG(llvm::dbgs()
+ << "Checking if type " << elementType << " with size " << size
+ << " can be vectorized: ");
+ if (!checkRes(isValidVecType(elementType, size)))
return false;
}
return true;
}
- if (auto extract = dyn_cast<vector::ExtractOp>(op))
- return isValidVecType(extract.getResult().getType(), size);
+ if (auto extract = dyn_cast<vector::ExtractOp>(op)) {
+ Type type = extract.getResult().getType();
+ LLVM_DEBUG(llvm::dbgs() << "Checking if type " << type << " with size "
+ << size << " can be vectorized: ");
+ return checkRes(isValidVecType(type, size));
+ }
LLVM_DEBUG(llvm::dbgs() << "Unsupported op " << op->getName() << "\n");
return false;
@@ -903,12 +960,19 @@ SLPGraph::vectorize(IRRewriter &rewriter,
for (auto *operand : node->operands)
size = std::min(size, operand->size());
- node->ops.resize(size);
+ if (size < node->size()) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Size mismatch, resizing node with " << node->size()
+ << " operations to " << size << "\n");
+ node->ops.resize(size);
+ }
while (node->size() > 1) {
if (checkOpVecType(node, isValidVecType))
break;
+ LLVM_DEBUG(llvm::dbgs() << "No a valid vector type, popping back op: "
+ << node->ops.back()->getName() << "\n");
node->ops.pop_back();
}
}
@@ -975,24 +1039,22 @@ SLPGraph::vectorize(IRRewriter &rewriter,
numElements, 1);
};
- if (auto load = dyn_cast<memref::LoadOp>(op)) {
- auto vecType =
- VectorType::get(numElements, load.getMemRefType().getElementType());
- Value result = rewriter.create<vector::LoadOp>(
- loc, vecType, load.getMemRef(), load.getIndices());
- mapping.map(load.getResult(), result);
+ if (maybeReadOp(op)) {
+ auto vecType = VectorType::get(numElements, getElementType(op));
+ Value result = rewriter.create<vector::LoadOp>(loc, vecType, getBase(op),
+ getIndices(op));
+ mapping.map(op->getResult(0), result);
handleNonVectorOutputs(result);
- } else if (auto store = dyn_cast<memref::StoreOp>(op)) {
- handleNonVectorInputs(store.getValueToStore());
- Value val = mapping.lookupOrDefault(store.getValueToStore());
+ } else if (maybeWriteOp(op)) {
+ handleNonVectorInputs(getValueToStore(op));
+ Value val = mapping.lookupOrDefault(getValueToStore(op));
val = handleVecSizeMismatch(val);
- rewriter.create<vector::StoreOp>(loc, val, store.getMemRef(),
- store.getIndices());
+ rewriter.create<vector::StoreOp>(loc, val, getBase(op), getIndices(op));
} else if (isVectorizable(op)) {
handleNonVectorInputs(op->getOperands());
Operation *newOp = rewriter.clone(*op, mapping);
- auto resVectorType =
- VectorType::get(numElements, op->getResultTypes().front());
+ Type resType = getElementTypeOrSelf(op->getResultTypes().front());
+ auto resVectorType = VectorType::get(numElements, resType);
{
OpBuilder::InsertionGuard guard(rewriter);
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index e339eb5755bd6..aeedececa1a7c 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -276,6 +276,80 @@ func.func @read_read_add_write(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
}
+// CHECK-LABEL: func @read_read_add_write_vec_0d
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_vec_0d(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+ // CHECK: vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %0 = vector.load %arg0[%c0] : memref<8xi32>, vector<i32>
+ %1 = vector.load %arg0[%c1] : memref<8xi32>, vector<i32>
+ %2 = vector.load %arg0[%c2] : memref<8xi32>, vector<i32>
+ %3 = vector.load %arg0[%c3] : memref<8xi32>, vector<i32>
+
+ %4 = vector.load %arg1[%c0] : memref<8xi32>, vector<i32>
+ %5 = vector.load %arg1[%c1] : memref<8xi32>, vector<i32>
+ %6 = vector.load %arg1[%c2] : memref<8xi32>, vector<i32>
+ %7 = vector.load %arg1[%c3] : memref<8xi32>, vector<i32>
+
+ %8 = arith.addi %0, %4 : vector<i32>
+ %9 = arith.addi %1, %5 : vector<i32>
+ %10 = arith.addi %2, %6 : vector<i32>
+ %11 = arith.addi %3, %7 : vector<i32>
+
+ vector.store %8, %arg0[%c0] : memref<8xi32>, vector<i32>
+ vector.store %9, %arg0[%c1] : memref<8xi32>, vector<i32>
+ vector.store %10, %arg0[%c2] : memref<8xi32>, vector<i32>
+ vector.store %11, %arg0[%c3] : memref<8xi32>, vector<i32>
+
+ return
+}
+
+
+// CHECK-LABEL: func @read_read_add_write_vec_1d
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_vec_1d(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+ // CHECK: vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %0 = vector.load %arg0[%c0] : memref<8xi32>, vector<1xi32>
+ %1 = vector.load %arg0[%c1] : memref<8xi32>, vector<1xi32>
+ %2 = vector.load %arg0[%c2] : memref<8xi32>, vector<1xi32>
+ %3 = vector.load %arg0[%c3] : memref<8xi32>, vector<1xi32>
+
+ %4 = vector.load %arg1[%c0] : memref<8xi32>, vector<1xi32>
+ %5 = vector.load %arg1[%c1] : memref<8xi32>, vector<1xi32>
+ %6 = vector.load %arg1[%c2] : memref<8xi32>, vector<1xi32>
+ %7 = vector.load %arg1[%c3] : memref<8xi32>, vector<1xi32>
+
+ %8 = arith.addi %0, %4 : vector<1xi32>
+ %9 = arith.addi %1, %5 : vector<1xi32>
+ %10 = arith.addi %2, %6 : vector<1xi32>
+ %11 = arith.addi %3, %7 : vector<1xi32>
+
+ vector.store %8, %arg0[%c0] : memref<8xi32>, vector<1xi32>
+ vector.store %9, %arg0[%c1] : memref<8xi32>, vector<1xi32>
+ vector.store %10, %arg0[%c2] : memref<8xi32>, vector<1xi32>
+ vector.store %11, %arg0[%c3] : memref<8xi32>, vector<1xi32>
+
+ return
+}
+
+
// CHECK-LABEL: func @read_read_add_write_seven
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xindex>, %[[ARG1:.*]]: memref<8xindex>)
func.func @read_read_add_write_seven(%arg0: memref<8xindex>, %arg1: memref<8xindex>) {
>From 5e473ab19148652d7688a43a51447e3880a239e9 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 31 May 2025 21:23:44 +0200
Subject: [PATCH 47/52] refac size() -> opsCount()
---
.../Vector/Transforms/SLPVectorizer.cpp | 45 ++++++++++---------
1 file changed, 24 insertions(+), 21 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index a9411c7c903bd..e2b3156bd8f6c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -49,7 +49,7 @@ struct MemoryOpGroup {
bool isLoadGroup() const { return type == Type::Load; }
bool isStoreGroup() const { return type == Type::Store; }
- size_t size() const { return ops.size(); }
+ size_t opsCount() const { return ops.size(); }
};
static bool maybeReadOp(Operation *op) {
@@ -305,7 +305,7 @@ extractContiguousGroups(const MemoryOpGroup &group) {
}
LLVM_DEBUG(llvm::dbgs() << "Extracted contiguous group with "
- << currentGroup.size() << " operations\n");
+ << currentGroup.opsCount() << " operations\n");
}
return result;
}
@@ -353,7 +353,7 @@ struct SLPGraphNode {
SLPGraphNode(ArrayRef<Operation *> operations)
: ops(operations.begin(), operations.end()) {}
- size_t size() const { return ops.size(); }
+ size_t opsCount() const { return ops.size(); }
Operation *op() const {
assert(!ops.empty() && "empty ops");
@@ -507,13 +507,14 @@ class SLPGraph {
if (!node->isRoot)
continue;
llvm::dbgs() << " " << (maybeReadOp(node->op()) ? "LOAD" : "STORE")
- << " group with " << node->size() << " operations:\n";
+ << " group with " << node->opsCount() << " operations:\n";
for (auto *op : node->ops) {
llvm::dbgs() << " " << *op << "\n";
}
llvm::dbgs() << " Users: ";
for (auto *user : node->users) {
- llvm::dbgs() << "\n Group with " << user->size() << " operations:";
+ llvm::dbgs() << "\n Group with " << user->opsCount()
+ << " operations:";
for (auto *op : user->ops) {
llvm::dbgs() << "\n " << *op;
}
@@ -526,13 +527,13 @@ class SLPGraph {
for (const auto &node : nodes) {
if (node->isRoot)
continue;
- llvm::dbgs() << " Group with " << node->size() << " operations:\n";
+ llvm::dbgs() << " Group with " << node->opsCount() << " operations:\n";
for (auto *op : node->ops) {
llvm::dbgs() << " " << *op << "\n";
}
llvm::dbgs() << " Operands: ";
for (auto *operand : node->operands) {
- llvm::dbgs() << "\n Group with " << operand->size()
+ llvm::dbgs() << "\n Group with " << operand->opsCount()
<< " operations:";
for (auto *op : operand->ops) {
llvm::dbgs() << "\n " << *op;
@@ -540,7 +541,8 @@ class SLPGraph {
}
llvm::dbgs() << "\n Users: ";
for (auto *user : node->users) {
- llvm::dbgs() << "\n Group with " << user->size() << " operations:";
+ llvm::dbgs() << "\n Group with " << user->opsCount()
+ << " operations:";
for (auto *op : user->ops) {
llvm::dbgs() << "\n " << *op;
}
@@ -697,7 +699,7 @@ static bool
checkOpVecType(SLPGraphNode *node,
llvm::function_ref<bool(Type, size_t)> isValidVecType) {
Operation *op = node->op();
- size_t size = node->size();
+ size_t size = node->opsCount();
auto checkRes = [](bool res) -> bool {
LLVM_DEBUG(llvm::dbgs() << (res ? "true" : "false") << "\n");
return res;
@@ -779,7 +781,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
worklist.push_back(node);
LLVM_DEBUG({
- llvm::dbgs() << "Created root group node with " << node->size()
+ llvm::dbgs() << "Created root group node with " << node->opsCount()
<< " operations of type "
<< (group.isLoadGroup() ? "Load" : "Store") << "\n";
});
@@ -907,7 +909,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
while (!worklist.empty()) {
SLPGraphNode *node = worklist.pop_back_val();
LLVM_DEBUG(llvm::dbgs()
- << "Processing node with " << node->size()
+ << "Processing node with " << node->opsCount()
<< " operations, first op: " << node->op()->getName() << "\n");
Operation *op = node->op();
@@ -940,7 +942,7 @@ SLPGraph::vectorize(IRRewriter &rewriter,
LLVM_DEBUG({
llvm::dbgs() << "Topologically sorted nodes:\n";
for (auto *node : sortedNodes) {
- llvm::dbgs() << " Node with " << node->size()
+ llvm::dbgs() << " Node with " << node->opsCount()
<< " operations: " << node->op()->getName() << "\n";
}
});
@@ -948,7 +950,8 @@ SLPGraph::vectorize(IRRewriter &rewriter,
auto isBadNode = [&](SLPGraphNode *node) {
// Do not vectorize stray nodes which are not connected to any other
// nodes.
- return (node->users.empty() && node->operands.empty()) || node->size() <= 1;
+ return (node->users.empty() && node->operands.empty()) ||
+ node->opsCount() <= 1;
};
// Update node vec sizes if its inputs' vec sizes are smaller.
@@ -956,18 +959,18 @@ SLPGraph::vectorize(IRRewriter &rewriter,
// TODO: It may be possible to reconstruct the larger vec size by combining
// the smaller src vector and a scalar arg.
for (auto *node : sortedNodes) {
- size_t size = node->size();
+ size_t size = node->opsCount();
for (auto *operand : node->operands)
- size = std::min(size, operand->size());
+ size = std::min(size, operand->opsCount());
- if (size < node->size()) {
+ if (size < node->opsCount()) {
LLVM_DEBUG(llvm::dbgs()
- << "Size mismatch, resizing node with " << node->size()
+ << "Size mismatch, resizing node with " << node->opsCount()
<< " operations to " << size << "\n");
node->ops.resize(size);
}
- while (node->size() > 1) {
+ while (node->opsCount() > 1) {
if (checkOpVecType(node, isValidVecType))
break;
@@ -982,7 +985,7 @@ SLPGraph::vectorize(IRRewriter &rewriter,
IRMapping mapping;
for (auto *node : sortedNodes) {
LLVM_DEBUG({
- llvm::dbgs() << "Processing node with " << node->size()
+ llvm::dbgs() << "Processing node with " << node->opsCount()
<< " operations\n";
llvm::dbgs() << " First op: " << *node->op() << "\n";
});
@@ -997,7 +1000,7 @@ SLPGraph::vectorize(IRRewriter &rewriter,
LLVM_DEBUG(llvm::dbgs() << " Insertion point: " << *ip << "\n");
rewriter.setInsertionPoint(ip);
- int64_t numElements = node->size();
+ int64_t numElements = node->opsCount();
Location loc = op->getLoc();
auto handleNonVectorInputs = [&](ValueRange operands) {
@@ -1115,7 +1118,7 @@ tryToVectorizeInBlock(Block &block,
<< " contiguous groups in "
<< (group.isLoadGroup() ? "load" : "store") << " group\n";
for (const auto &contigGroup : contiguousGroups) {
- llvm::dbgs() << " Contiguous group with " << contigGroup.size()
+ llvm::dbgs() << " Contiguous group with " << contigGroup.opsCount()
<< " operations\n";
}
});
>From 4161f5a63af8e950ba8150eafeb4439ecf5d8bb3 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 31 May 2025 22:37:50 +0200
Subject: [PATCH 48/52] merge vectorized ops too
---
.../Vector/Transforms/SLPVectorizer.cpp | 129 ++++++++++++------
mlir/test/Dialect/Vector/slp-vectorize.mlir | 34 ++---
2 files changed, 105 insertions(+), 58 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index e2b3156bd8f6c..c010193226814 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -43,6 +43,7 @@ struct MemoryOpGroup {
enum class Type { Load, Store };
Type type;
SmallVector<Operation *> ops;
+ int64_t elementsCount = 0;
MemoryOpGroup(Type t) : type(t) {}
@@ -68,30 +69,37 @@ static bool maybeWriteOp(Operation *op) {
return effectInterface.hasEffect<MemoryEffects::Write>();
}
-static Type getVectorElementType(VectorType vectorType) {
- if (vectorType.getRank() > 1 || vectorType.isScalable() ||
- vectorType.getNumElements() != 1)
- return {};
+static std::optional<std::pair<Type, int64_t>>
+getVectorElementTypeAndCount(VectorType vectorType) {
+ if (vectorType.getRank() > 1 || vectorType.isScalable())
+ return std::nullopt;
- return vectorType.getElementType();
+ return std::make_pair(vectorType.getElementType(),
+ vectorType.getNumElements());
}
-static Type getElementType(Operation *op) {
+static std::optional<std::pair<Type, int64_t>>
+getElementTypeAndCount(Operation *op) {
assert(op && "null op");
if (auto loadOp = dyn_cast<memref::LoadOp>(op))
- return loadOp.getResult().getType();
+ return std::make_pair(loadOp.getResult().getType(), 1);
if (auto storeOp = dyn_cast<memref::StoreOp>(op))
- return storeOp.getValueToStore().getType();
+ return std::make_pair(storeOp.getValueToStore().getType(), 1);
if (auto loadOp = dyn_cast<vector::LoadOp>(op))
- return getVectorElementType(loadOp.getVectorType());
+ return getVectorElementTypeAndCount(loadOp.getVectorType());
if (auto storeOp = dyn_cast<vector::StoreOp>(op))
- return getVectorElementType(storeOp.getVectorType());
- return {};
+ return getVectorElementTypeAndCount(storeOp.getVectorType());
+ return std::nullopt;
}
static bool isSupportedMemOp(Operation *op) {
assert(op && "null op");
- return isa_and_present<IntegerType, FloatType, IndexType>(getElementType(op));
+ auto typeAndCount = getElementTypeAndCount(op);
+ if (!typeAndCount)
+ return false;
+
+ return isa_and_present<IntegerType, FloatType, IndexType>(
+ typeAndCount->first);
}
/// Collect all memory operations in the block into groups.
@@ -177,7 +185,7 @@ static ValueRange getIndices(Operation *op) {
return {};
}
-static bool isAdjacentAffineMapIndices(Value idx1, Value idx2) {
+static bool isAdjacentAffineMapIndices(Value idx1, Value idx2, int64_t offset) {
auto applyOp1 = idx1.getDefiningOp<affine::AffineApplyOp>();
if (!applyOp1)
return false;
@@ -195,28 +203,29 @@ static bool isAdjacentAffineMapIndices(Value idx1, Value idx2) {
simplifyAffineExpr(expr2 - expr1, 0, applyOp1.getOperands().size());
auto diffConst = dyn_cast<AffineConstantExpr>(diff);
- return diffConst && diffConst.getValue() == 1;
+ return diffConst && diffConst.getValue() == offset;
}
/// Check if two indices are consecutive, i.e. index1 + offset == index2.
-static bool isAdjacentIndices(Value idx1, Value idx2) {
+static bool isAdjacentIndices(Value idx1, Value idx2, int64_t offset) {
if (auto c1 = getConstantIntValue(idx1)) {
if (auto c2 = getConstantIntValue(idx2))
- return *c1 + 1 == *c2;
+ return *c1 + offset == *c2;
}
if (auto addOp2 = idx2.getDefiningOp<arith::AddIOp>()) {
- if (addOp2.getLhs() == idx1 && getConstantIntValue(addOp2.getRhs()) == 1)
+ if (addOp2.getLhs() == idx1 &&
+ getConstantIntValue(addOp2.getRhs()) == offset)
return true;
if (auto addOp1 = idx1.getDefiningOp<arith::AddIOp>()) {
if (addOp1.getLhs() == addOp2.getLhs() &&
- isAdjacentIndices(addOp1.getRhs(), addOp2.getRhs()))
+ isAdjacentIndices(addOp1.getRhs(), addOp2.getRhs(), offset))
return true;
}
}
- if (isAdjacentAffineMapIndices(idx1, idx2))
+ if (isAdjacentAffineMapIndices(idx1, idx2, offset))
return true;
return false;
@@ -224,19 +233,22 @@ static bool isAdjacentIndices(Value idx1, Value idx2) {
/// Check if two ranges of indices are consecutive, i.e. the fastest index
/// differs by `offset` and all other indices are the same.
-static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2) {
+static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2,
+ int64_t offset) {
if (idx1.empty() || idx1.size() != idx2.size())
return false;
if (idx1.drop_back() != idx2.drop_back())
return false;
- return isAdjacentIndices(idx1.back(), idx2.back());
+ return isAdjacentIndices(idx1.back(), idx2.back(), offset);
}
/// Check if two operations are adjacent and can be combined into a vector op.
/// This is done by checking if the base memrefs are the same, the last
-/// dimension is contiguous, and the element types and indices are compatible
+/// dimension is contiguous, and the element types and indices are compatible.
+/// If the source read/write is already vectorized, only merge ops if the
+/// vector element counts match.
static bool isAdjacentOps(Operation *op1, Operation *op2) {
assert(op1 && "null op1");
assert(op2 && "null op2");
@@ -249,10 +261,19 @@ static bool isAdjacentOps(Operation *op1, Operation *op2) {
if (!isContiguousLastDim(base1))
return false;
- if (getElementType(op1) != getElementType(op2))
+ auto typeAndCount1 = getElementTypeAndCount(op1);
+ if (!typeAndCount1)
+ return false;
+
+ auto typeAndCount2 = getElementTypeAndCount(op2);
+ if (!typeAndCount2)
return false;
- return isAdjacentIndices(getIndices(op1), getIndices(op2));
+ if (typeAndCount1 != typeAndCount2)
+ return false;
+
+ return isAdjacentIndices(getIndices(op1), getIndices(op2),
+ typeAndCount1->second);
}
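Illustrative sketch, not part of the patch: with the element count used as the adjacency offset, two vector<2xi32> loads at indices 0 and 2 are considered adjacent (0 + 2 == 2) and become candidates for a single vector<4xi32> load. Function and value names below are hypothetical:

  // Sketch only, not from the patch.
  func.func @adjacent_vec_loads(%src: memref<8xi32>) -> (vector<2xi32>, vector<2xi32>) {
    %c0 = arith.constant 0 : index
    %c2 = arith.constant 2 : index
    %a = vector.load %src[%c0] : memref<8xi32>, vector<2xi32>
    %b = vector.load %src[%c2] : memref<8xi32>, vector<2xi32>
    return %a, %b : vector<2xi32>, vector<2xi32>
  }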
// Extract contiguous groups from a MemoryOpGroup
@@ -271,6 +292,7 @@ extractContiguousGroups(const MemoryOpGroup &group) {
// Start a new group with this operation
result.emplace_back(group.type);
MemoryOpGroup &currentGroup = result.back();
+ currentGroup.elementsCount = getElementTypeAndCount(op)->second;
auto &currentOps = currentGroup.ops;
currentOps.push_back(op);
processedOps.insert(op);
@@ -310,7 +332,9 @@ extractContiguousGroups(const MemoryOpGroup &group) {
return result;
}
-static bool isVectorizable(Operation *op) {
+static bool
+isVectorizable(Operation *op,
+ std::optional<int64_t> expectedElementsCount = std::nullopt) {
if (!OpTrait::hasElementwiseMappableTraits(op))
return false;
@@ -319,14 +343,18 @@ static bool isVectorizable(Operation *op) {
for (auto type :
llvm::concat<Type>(op->getResultTypes(), op->getOperandTypes())) {
+ int64_t vectorElementsCount = 1;
if (auto vectorType = dyn_cast<VectorType>(type)) {
- if (vectorType.getRank() > 1 || vectorType.isScalable() ||
- vectorType.getNumElements() != 1)
+ if (vectorType.getRank() > 1 || vectorType.isScalable())
return false;
type = vectorType.getElementType();
+ vectorElementsCount = vectorType.getNumElements();
}
+ if (expectedElementsCount && vectorElementsCount != *expectedElementsCount)
+ return false;
+
if (!isa<IntegerType, FloatType, IndexType>(type))
return false;
}
@@ -347,6 +375,7 @@ struct SLPGraphNode {
SmallVector<SLPGraphNode *> users;
SmallVector<SLPGraphNode *> operands;
Operation *insertionPoint = nullptr;
+ int64_t elementsCount = 0;
bool isRoot = false;
SLPGraphNode() = default;
@@ -354,6 +383,7 @@ struct SLPGraphNode {
: ops(operations.begin(), operations.end()) {}
size_t opsCount() const { return ops.size(); }
+ size_t vectorSize() const { return elementsCount * opsCount(); }
Operation *op() const {
assert(!ops.empty() && "empty ops");
@@ -415,17 +445,20 @@ class SLPGraph {
SLPGraph &operator=(SLPGraph &&) = default;
/// Add a new node to the graph
- SLPGraphNode *addNode(ArrayRef<Operation *> operations) {
+ SLPGraphNode *addNode(ArrayRef<Operation *> operations,
+ int64_t elementsCount) {
nodes.push_back(std::make_unique<SLPGraphNode>(operations));
auto *node = nodes.back().get();
+ node->elementsCount = elementsCount;
for (Operation *op : operations)
opToNode[op] = node;
return node;
}
/// Add a root node (memory operation)
- SLPGraphNode *addRoot(ArrayRef<Operation *> operations) {
- auto *node = addNode(operations);
+ SLPGraphNode *addRoot(ArrayRef<Operation *> operations,
+ int64_t elementsCount) {
+ auto *node = addNode(operations, elementsCount);
node->isRoot = true;
return node;
}
@@ -699,13 +732,14 @@ static bool
checkOpVecType(SLPGraphNode *node,
llvm::function_ref<bool(Type, size_t)> isValidVecType) {
Operation *op = node->op();
- size_t size = node->opsCount();
+ size_t size = node->vectorSize();
auto checkRes = [](bool res) -> bool {
LLVM_DEBUG(llvm::dbgs() << (res ? "true" : "false") << "\n");
return res;
};
- if (Type elementType = getElementType(op)) {
+ if (auto typeAndCount = getElementTypeAndCount(op)) {
+ Type elementType = typeAndCount->first;
LLVM_DEBUG(llvm::dbgs() << "Checking if type " << elementType
<< " with size " << size << " can be vectorized: ");
return checkRes(isValidVecType(elementType, size));
@@ -777,7 +811,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
// First, create nodes for each contiguous memory operation group
for (const auto &group : rootGroups) {
- auto *node = graph.addRoot(group.ops);
+ auto *node = graph.addRoot(group.ops, group.elementsCount);
worklist.push_back(node);
LLVM_DEBUG({
@@ -800,7 +834,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
return;
}
- if (!isVectorizable(user))
+ if (!isVectorizable(user, node->elementsCount))
return;
Fingerprint expectedFingerprint = fingerprints.getFingerprint(user);
@@ -830,7 +864,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
if (currentOps.size() == 1)
return;
- auto *newNode = graph.addNode(currentOps);
+ auto *newNode = graph.addNode(currentOps, node->elementsCount);
graph.addEdge(node, newNode);
for (Operation *op : currentOps)
fingerprints.invalidate(op);
@@ -877,7 +911,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
currentOps.push_back(otherOp);
++currentIndex;
}
- } else if (isVectorizable(srcOp)) {
+ } else if (isVectorizable(srcOp, node->elementsCount)) {
LLVM_DEBUG(llvm::dbgs() << " Processing vectorizable op "
<< srcOp->getName() << "\n");
@@ -898,7 +932,7 @@ static SLPGraph buildSLPGraph(ArrayRef<MemoryOpGroup> rootGroups) {
if (currentOps.size() == 1)
return;
- auto *newNode = graph.addNode(currentOps);
+ auto *newNode = graph.addNode(currentOps, node->elementsCount);
graph.addEdge(newNode, node);
for (Operation *op : currentOps)
fingerprints.invalidate(op);
@@ -1000,7 +1034,7 @@ SLPGraph::vectorize(IRRewriter &rewriter,
LLVM_DEBUG(llvm::dbgs() << " Insertion point: " << *ip << "\n");
rewriter.setInsertionPoint(ip);
- int64_t numElements = node->opsCount();
+ int64_t numElements = node->vectorSize();
Location loc = op->getLoc();
auto handleNonVectorInputs = [&](ValueRange operands) {
@@ -1009,10 +1043,20 @@ SLPGraph::vectorize(IRRewriter &rewriter,
continue;
SmallVector<Value> args;
- for (Operation *defOp : node->ops)
- args.push_back(defOp->getOperand(i));
+ for (Operation *defOp : node->ops) {
+ Value arg = defOp->getOperand(i);
+ if (auto vecType = dyn_cast<VectorType>(arg.getType())) {
+ assert(vecType.getRank() == 1);
+ for (auto j : llvm::seq(vecType.getNumElements()))
+ args.push_back(rewriter.create<vector::ExtractOp>(loc, arg, j));
+
+ } else {
+ args.push_back(arg);
+ }
+ }
- auto vecType = VectorType::get(numElements, operand.getType());
+ auto vecType = VectorType::get(numElements,
+ getElementTypeOrSelf(operand.getType()));
Value vector =
rewriter.create<vector::FromElementsOp>(loc, vecType, args);
mapping.map(operand, vector);
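Illustrative sketch, not part of the patch, of the shape of IR this produces when a node operand is itself a narrower vector: its elements are extracted and re-packed, together with any scalar operands, into the node-wide vector via vector.from_elements (function name hypothetical):

  // Sketch only, not from the patch.
  func.func @repack_operands(%v: vector<2xi32>, %s0: i32, %s1: i32) -> vector<4xi32> {
    %e0 = vector.extract %v[0] : i32 from vector<2xi32>
    %e1 = vector.extract %v[1] : i32 from vector<2xi32>
    %w = vector.from_elements %e0, %e1, %s0, %s1 : vector<4xi32>
    return %w : vector<4xi32>
  }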
@@ -1043,7 +1087,8 @@ SLPGraph::vectorize(IRRewriter &rewriter,
};
if (maybeReadOp(op)) {
- auto vecType = VectorType::get(numElements, getElementType(op));
+ auto vecType =
+ VectorType::get(numElements, getElementTypeAndCount(op)->first);
Value result = rewriter.create<vector::LoadOp>(loc, vecType, getBase(op),
getIndices(op));
mapping.map(op->getResult(0), result);
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index aeedececa1a7c..598ba5c755ab1 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -646,22 +646,24 @@ func.func private @use(i32)
// CHECK-LABEL: func @read_read_add_write_interleaved_use
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
func.func @read_read_add_write_interleaved_use(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[C2:.*]] = arith.constant 2 : index
- // CHECK: %[[C3:.*]] = arith.constant 3 : index
- // CHECK: %[[V0:.*]] = memref.load %arg0[%[[C3]]] : memref<8xi32>
- // CHECK: %[[V1:.*]] = memref.load %arg1[%[[C3]]] : memref<8xi32>
- // CHECK: call @use(%[[V0]]) : (i32) -> ()
- // CHECK: %[[V2:.*]] = vector.load %arg0[%[[C0]]] : memref<8xi32>, vector<2xi32>
- // CHECK: %[[V3:.*]] = vector.load %arg1[%[[C0]]] : memref<8xi32>, vector<2xi32>
- // CHECK: %[[V4:.*]] = memref.load %arg0[%[[C2]]] : memref<8xi32>
- // CHECK: %[[V5:.*]] = memref.load %arg1[%[[C2]]] : memref<8xi32>
- // CHECK: %[[V6:.*]] = vector.from_elements %[[V4]], %[[V0]] : vector<2xi32>
- // CHECK: %[[V7:.*]] = vector.from_elements %[[V5]], %[[V1]] : vector<2xi32>
- // CHECK: %[[V8:.*]] = arith.addi %[[V6]], %[[V7]] : vector<2xi32>
- // CHECK: %[[V9:.*]] = arith.addi %[[V2]], %[[V3]] : vector<2xi32>
- // CHECK: vector.store %[[V9]], %arg0[%[[C0]]] : memref<8xi32>, vector<2xi32>
- // CHECK: vector.store %[[V8]], %arg0[%[[C2]]] : memref<8xi32>, vector<2xi32>
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+ // CHECK: %[[V0:.*]] = memref.load %[[ARG0]][%[[C3]]] : memref<8xi32>
+ // CHECK: %[[V1:.*]] = memref.load %[[ARG1]][%[[C3]]] : memref<8xi32>
+ // CHECK: call @use(%[[V0]]) : (i32) -> ()
+ // CHECK: %[[V2:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+ // CHECK: %[[V3:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+ // CHECK: %[[V4:.*]] = memref.load %[[ARG0]][%[[C2]]] : memref<8xi32>
+ // CHECK: %[[V5:.*]] = memref.load %[[ARG1]][%[[C2]]] : memref<8xi32>
+ // CHECK: %[[V6:.*]] = vector.extract %[[V2]][0] : i32 from vector<2xi32>
+ // CHECK: %[[V7:.*]] = vector.extract %[[V2]][1] : i32 from vector<2xi32>
+ // CHECK: %[[V8:.*]] = vector.from_elements %[[V6]], %[[V7]], %[[V4]], %[[V0]] : vector<4xi32>
+ // CHECK: %[[V9:.*]] = vector.extract %[[V3]][0] : i32 from vector<2xi32>
+ // CHECK: %[[V10:.*]] = vector.extract %[[V3]][1] : i32 from vector<2xi32>
+ // CHECK: %[[V11:.*]] = vector.from_elements %[[V9]], %[[V10]], %[[V5]], %[[V1]] : vector<4xi32>
+ // CHECK: %[[V12:.*]] = arith.addi %[[V8]], %[[V11]] : vector<4xi32>
+ // CHECK: vector.store %[[V12]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
>From 0e557a7e5965d2c7b11f65ffd88852828b8ca29e Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 1 Jun 2025 12:05:59 +0200
Subject: [PATCH 49/52] more tests
---
mlir/test/Dialect/Vector/slp-vectorize.mlir | 32 +++++++++++++++++++++
1 file changed, 32 insertions(+)
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 598ba5c755ab1..a6108287551f4 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -350,6 +350,38 @@ func.func @read_read_add_write_vec_1d(%arg0: memref<8xi32>, %arg1: memref<8xi32>
}
+// CHECK-LABEL: func @read_read_add_write_mixed_vecs
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_write_mixed_vecs(%arg0: memref<8xi32>, %arg1: memref<8xi32>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[A:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[B:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[RES:.*]] = arith.addi %[[A]], %[[B]] : vector<4xi32>
+ // CHECK: vector.store %[[RES]], %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+
+ %0 = vector.load %arg0[%c0] : memref<8xi32>, vector<2xi32>
+ %2 = memref.load %arg0[%c2] : memref<8xi32>
+ %3 = vector.load %arg0[%c3] : memref<8xi32>, vector<1xi32>
+
+ %4 = vector.load %arg1[%c0] : memref<8xi32>, vector<2xi32>
+ %6 = memref.load %arg1[%c2] : memref<8xi32>
+ %7 = vector.load %arg1[%c3] : memref<8xi32>, vector<1xi32>
+
+ %8 = arith.addi %0, %4 : vector<2xi32>
+ %10 = arith.addi %2, %6 : i32
+ %11 = arith.addi %3, %7 : vector<1xi32>
+
+ vector.store %8, %arg0[%c0] : memref<8xi32>, vector<2xi32>
+ memref.store %10, %arg0[%c2] : memref<8xi32>
+ vector.store %11, %arg0[%c3] : memref<8xi32>, vector<1xi32>
+
+ return
+}
+
+
// CHECK-LABEL: func @read_read_add_write_seven
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xindex>, %[[ARG1:.*]]: memref<8xindex>)
func.func @read_read_add_write_seven(%arg0: memref<8xindex>, %arg1: memref<8xindex>) {
>From 8e2d6118a7d82c1106088ac24f5d745bff0958cc Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 1 Jun 2025 12:35:38 +0200
Subject: [PATCH 50/52] comments
---
.../Vector/Transforms/SLPVectorizer.cpp | 33 +++++++++++++++----
1 file changed, 26 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index c010193226814..3af4b9dfb4ce4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -89,6 +89,7 @@ getElementTypeAndCount(Operation *op) {
return getVectorElementTypeAndCount(loadOp.getVectorType());
if (auto storeOp = dyn_cast<vector::StoreOp>(op))
return getVectorElementTypeAndCount(storeOp.getVectorType());
+
return std::nullopt;
}
@@ -147,7 +148,8 @@ static Value getBase(Operation *op) {
return loadOp.getBase();
if (auto storeOp = dyn_cast<vector::StoreOp>(op))
return storeOp.getBase();
- return {};
+
+ llvm_unreachable("unsupported op");
}
static Value getValueToStore(Operation *op) {
@@ -156,7 +158,8 @@ static Value getValueToStore(Operation *op) {
return storeOp.getValueToStore();
if (auto storeOp = dyn_cast<vector::StoreOp>(op))
return storeOp.getValueToStore();
- return {};
+
+ llvm_unreachable("unsupported op");
}
static bool isContiguousLastDim(Value val) {
@@ -182,7 +185,8 @@ static ValueRange getIndices(Operation *op) {
return loadOp.getIndices();
if (auto storeOp = dyn_cast<vector::StoreOp>(op))
return storeOp.getIndices();
- return {};
+
+ llvm_unreachable("unsupported op");
}
static bool isAdjacentAffineMapIndices(Value idx1, Value idx2, int64_t offset) {
@@ -206,7 +210,7 @@ static bool isAdjacentAffineMapIndices(Value idx1, Value idx2, int64_t offset) {
return diffConst && diffConst.getValue() == offset;
}
-/// Check if two indices are consecutive, i.e index1 + 1 == index2.
+/// Check if two indices are consecutive, i.e index1 + offset == index2.
static bool isAdjacentIndices(Value idx1, Value idx2, int64_t offset) {
if (auto c1 = getConstantIntValue(idx1)) {
if (auto c2 = getConstantIntValue(idx2))
@@ -232,7 +236,7 @@ static bool isAdjacentIndices(Value idx1, Value idx2, int64_t offset) {
}
/// Check if two ranges of indices are consecutive, i.e fastest index differs
-/// by 1 and all other indices are the same.
+/// by `offset` and all other indices are the same.
static bool isAdjacentIndices(ValueRange idx1, ValueRange idx2,
int64_t offset) {
if (idx1.empty() || idx1.size() != idx2.size())
@@ -272,6 +276,7 @@ static bool isAdjacentOps(Operation *op1, Operation *op2) {
if (typeAndCount1 != typeAndCount2)
return false;
+ // For now, only merge ops with the same element count.
return isAdjacentIndices(getIndices(op1), getIndices(op2),
typeAndCount1->second);
}
@@ -332,6 +337,9 @@ extractContiguousGroups(const MemoryOpGroup &group) {
return result;
}
+/// Check if an operation is vectorizable.
+/// If `expectedElementsCount` is provided, also check that the original op had
+/// the specified number of vector elements.
static bool
isVectorizable(Operation *op,
std::optional<int64_t> expectedElementsCount = std::nullopt) {
@@ -362,7 +370,8 @@ isVectorizable(Operation *op,
return true;
}
-/// Get the next operation in the block, assuming `op` is not a terminator.
+/// Get the next operation in the block, assuming `op` is neither a terminator
+/// nor the last operation in the block.
static Operation *nextOp(Operation *op) {
assert(op && "null op");
auto it = op->getIterator();
@@ -390,6 +399,9 @@ struct SLPGraphNode {
return ops.front();
}
+ /// Get a suitable insertion point for the new vectorized op.
+ /// This method also takes the operands' insertion points into account to
+ /// satisfy dominance relations.
Operation *getInsertionPoint() {
assert(!ops.empty() && "empty node");
if (insertionPoint)
@@ -1038,6 +1050,9 @@ SLPGraph::vectorize(IRRewriter &rewriter,
Location loc = op->getLoc();
auto handleNonVectorInputs = [&](ValueRange operands) {
+ // Handle the case where op operands are not vectorized or have a smaller
+ // vector size; construct the vector from the scalar operands using
+ // FromElementsOp.
for (auto [i, operand] : llvm::enumerate(operands)) {
if (getNodeForOp(operand.getDefiningOp()))
continue;
@@ -1064,6 +1079,8 @@ SLPGraph::vectorize(IRRewriter &rewriter,
};
auto handleNonVectorOutputs = [&](Value newResult) {
+ // Handle the case where op results are not vectorized or have a smaller
+ // vector size; extract the elements from the vector.
for (auto [i, result] : llvm::enumerate(node->ops)) {
for (OpOperand &use : result->getUses()) {
Operation *useOwner = use.getOwner();
@@ -1077,6 +1094,7 @@ SLPGraph::vectorize(IRRewriter &rewriter,
};
auto handleVecSizeMismatch = [&](Value arg, int64_t offset = 0) -> Value {
+ // Handle vector size mismatch between two vectorized nodes.
auto srcType = cast<VectorType>(arg.getType());
assert(srcType.getRank() == 1);
if (srcType.getDimSize(0) == numElements)
@@ -1117,7 +1135,8 @@ SLPGraph::vectorize(IRRewriter &rewriter,
mapping.map(op->getResults(), newOp->getResults());
handleNonVectorOutputs(newOp->getResult(0));
} else if (auto extract = dyn_cast<vector::ExtractOp>(op)) {
- // We alredy verified index is valid during graph construction.
+ // We already verified the index is valid during graph construction, so we
+ // don't need to check the `getExtractIndex` result.
int64_t offset = *getExtractIndex(extract);
Value val = handleVecSizeMismatch(extract.getVector(), offset);
mapping.map(extract.getResult(), val);
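Illustrative sketch, not part of the patch: reconciling a vector size mismatch amounts to slicing the narrower piece a consumer expects out of the wider combined vector, as the extract_strided_slice pattern in the new tests checks (function name hypothetical):

  // Sketch only, not from the patch.
  func.func @narrow_result(%wide: vector<4xi32>) -> vector<2xi32> {
    %part = vector.extract_strided_slice %wide {offsets = [2], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
    return %part : vector<2xi32>
  }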
>From 04cc9219d07e5f82be86f228d22f5ba4925b9f6d Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 1 Jun 2025 12:48:06 +0200
Subject: [PATCH 51/52] vector outputs handling
---
.../Vector/Transforms/SLPVectorizer.cpp | 21 +++++++++++----
mlir/test/Dialect/Vector/slp-vectorize.mlir | 26 +++++++++++++++++++
2 files changed, 42 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 3af4b9dfb4ce4..9a389612567df 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -1078,7 +1078,8 @@ SLPGraph::vectorize(IRRewriter &rewriter,
}
};
- auto handleNonVectorOutputs = [&](Value newResult) {
+ auto handleNonVectorOutputs = [&](Value newResult,
+ Type originalResultType) {
// Handle the case when op results are not vectorized or have smaller
// vector size, extract the elements from the vector.
for (auto [i, result] : llvm::enumerate(node->ops)) {
@@ -1087,7 +1088,16 @@ SLPGraph::vectorize(IRRewriter &rewriter,
if (getNodeForOp(useOwner))
continue;
- Value elem = rewriter.create<vector::ExtractOp>(loc, newResult, i);
+ int64_t offset = i * node->elementsCount;
+ Value elem;
+
+ if (auto vecType = dyn_cast<VectorType>(originalResultType)) {
+ elem = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, newResult, offset, vecType.getNumElements(), 1);
+ } else {
+ elem = rewriter.create<vector::ExtractOp>(loc, newResult, offset);
+ }
+
use.set(elem);
}
}
@@ -1109,8 +1119,9 @@ SLPGraph::vectorize(IRRewriter &rewriter,
VectorType::get(numElements, getElementTypeAndCount(op)->first);
Value result = rewriter.create<vector::LoadOp>(loc, vecType, getBase(op),
getIndices(op));
- mapping.map(op->getResult(0), result);
- handleNonVectorOutputs(result);
+ Value originalResult = op->getResult(0);
+ mapping.map(originalResult, result);
+ handleNonVectorOutputs(result, originalResult.getType());
} else if (maybeWriteOp(op)) {
handleNonVectorInputs(getValueToStore(op));
Value val = mapping.lookupOrDefault(getValueToStore(op));
@@ -1133,7 +1144,7 @@ SLPGraph::vectorize(IRRewriter &rewriter,
newOp->getResult(0).setType(resVectorType);
mapping.map(op->getResults(), newOp->getResults());
- handleNonVectorOutputs(newOp->getResult(0));
+ handleNonVectorOutputs(newOp->getResult(0), op->getResultTypes().front());
} else if (auto extract = dyn_cast<vector::ExtractOp>(op)) {
- // We already verified the index is valid during graph construction, so we
- // don't need to check the `getExtractIndex` result.
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index a6108287551f4..38490ba4934a4 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -672,6 +672,32 @@ func.func @read_read_add_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>,
}
+// CHECK-LABEL: func @read_read_add_add_vec
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_add_vec(%arg0: memref<8xi32>, %arg1: memref<8xi32>) ->
+ (vector<2xi32>, vector<2xi32>){
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[V1:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<4xi32>
+ // CHECK: %[[V2:.*]] = arith.addi %[[V0]], %[[V1]] : vector<4xi32>
+ // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+ // CHECK: %[[V4:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xi32> to vector<2xi32>
+ // CHECK: return %[[V3]], %[[V4]] : vector<2xi32>, vector<2xi32>
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+
+ %0 = vector.load %arg0[%c0] : memref<8xi32>, vector<2xi32>
+ %2 = vector.load %arg0[%c2] : memref<8xi32>, vector<2xi32>
+
+ %4 = vector.load %arg1[%c0] : memref<8xi32>, vector<2xi32>
+ %6 = vector.load %arg1[%c2] : memref<8xi32>, vector<2xi32>
+
+ %8 = arith.addi %0, %4 : vector<2xi32>
+ %10 = arith.addi %2, %6 : vector<2xi32>
+
+ return %8, %10 : vector<2xi32>, vector<2xi32>
+}
+
func.func private @use(i32)
>From 8578695b9a84880319e50b0bfc5c591bca094cc5 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 1 Jun 2025 12:58:42 +0200
Subject: [PATCH 52/52] vector handling
---
.../Vector/Transforms/SLPVectorizer.cpp | 10 +++-
mlir/test/Dialect/Vector/slp-vectorize.mlir | 56 +++++++++++++++++++
2 files changed, 64 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
index 9a389612567df..58c4c5b271458 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp
@@ -1092,8 +1092,14 @@ SLPGraph::vectorize(IRRewriter &rewriter,
Value elem;
if (auto vecType = dyn_cast<VectorType>(originalResultType)) {
- elem = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, newResult, offset, vecType.getNumElements(), 1);
+ assert(vecType.getRank() <= 1);
+ if (vecType.getRank() == 0) {
+ elem = rewriter.create<vector::ExtractOp>(loc, newResult, offset);
+ elem = rewriter.create<vector::SplatOp>(loc, vecType, elem);
+ } else {
+ elem = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, newResult, offset, vecType.getNumElements(), 1);
+ }
} else {
elem = rewriter.create<vector::ExtractOp>(loc, newResult, offset);
}
diff --git a/mlir/test/Dialect/Vector/slp-vectorize.mlir b/mlir/test/Dialect/Vector/slp-vectorize.mlir
index 38490ba4934a4..29c077d7ab34f 100644
--- a/mlir/test/Dialect/Vector/slp-vectorize.mlir
+++ b/mlir/test/Dialect/Vector/slp-vectorize.mlir
@@ -699,6 +699,62 @@ func.func @read_read_add_add_vec(%arg0: memref<8xi32>, %arg1: memref<8xi32>) ->
}
+// CHECK-LABEL: func @read_read_add_add_vec1
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_add_vec1(%arg0: memref<8xi32>, %arg1: memref<8xi32>) ->
+ (vector<1xi32>, vector<1xi32>){
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+ // CHECK: %[[V1:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+ // CHECK: %[[V2:.*]] = arith.addi %[[V0]], %[[V1]] : vector<2xi32>
+ // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xi32> to vector<1xi32>
+ // CHECK: %[[V4:.*]] = vector.extract_strided_slice %[[V2]] {offsets = [1], sizes = [1], strides = [1]} : vector<2xi32> to vector<1xi32>
+ // CHECK: return %[[V3]], %[[V4]] : vector<1xi32>, vector<1xi32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+
+ %0 = vector.load %arg0[%c0] : memref<8xi32>, vector<1xi32>
+ %2 = vector.load %arg0[%c1] : memref<8xi32>, vector<1xi32>
+
+ %4 = vector.load %arg1[%c0] : memref<8xi32>, vector<1xi32>
+ %6 = vector.load %arg1[%c1] : memref<8xi32>, vector<1xi32>
+
+ %8 = arith.addi %0, %4 : vector<1xi32>
+ %10 = arith.addi %2, %6 : vector<1xi32>
+
+ return %8, %10 : vector<1xi32>, vector<1xi32>
+}
+
+
+// CHECK-LABEL: func @read_read_add_add_vec0d
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
+func.func @read_read_add_add_vec0d(%arg0: memref<8xi32>, %arg1: memref<8xi32>) ->
+ (vector<i32>, vector<i32>){
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[V0:.*]] = vector.load %[[ARG0]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+ // CHECK: %[[V1:.*]] = vector.load %[[ARG1]][%[[C0]]] : memref<8xi32>, vector<2xi32>
+ // CHECK: %[[V2:.*]] = arith.addi %[[V0]], %[[V1]] : vector<2xi32>
+ // CHECK: %[[V3:.*]] = vector.extract %[[V2]][0] : i32 from vector<2xi32>
+ // CHECK: %[[V4:.*]] = vector.splat %[[V3]] : vector<i32>
+ // CHECK: %[[V5:.*]] = vector.extract %[[V2]][1] : i32 from vector<2xi32>
+ // CHECK: %[[V6:.*]] = vector.splat %[[V5]] : vector<i32>
+ // CHECK: return %[[V4]], %[[V6]] : vector<i32>, vector<i32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+
+ %0 = vector.load %arg0[%c0] : memref<8xi32>, vector<i32>
+ %2 = vector.load %arg0[%c1] : memref<8xi32>, vector<i32>
+
+ %4 = vector.load %arg1[%c0] : memref<8xi32>, vector<i32>
+ %6 = vector.load %arg1[%c1] : memref<8xi32>, vector<i32>
+
+ %8 = arith.addi %0, %4 : vector<i32>
+ %10 = arith.addi %2, %6 : vector<i32>
+
+ return %8, %10 : vector<i32>, vector<i32>
+}
+
+
func.func private @use(i32)
// CHECK-LABEL: func @read_read_add_write_interleaved_use