[Mlir-commits] [mlir] [mlir][spirv] Fix FuncOpVectorUnroll to process placeholder values in all blocks (PR #142339)
Darren Wihandi
llvmlistbot at llvm.org
Thu Jun 12 21:11:14 PDT 2025
https://github.com/fairywreath updated https://github.com/llvm/llvm-project/pull/142339
>From 393ef3602a60f987e8d14dabeae2b32071504813 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Mon, 2 Jun 2025 02:43:54 -0400
Subject: [PATCH 1/5] [mlir][spirv] Fix FuncOpVectorUnroll to process all
blocks
---
.../SPIRV/Transforms/SPIRVConversion.cpp | 26 ++++++++++---------
1 file changed, 14 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 62a24646d0662..84796fdeda03a 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1020,35 +1020,37 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
entryBlock.addArguments(convertedTypes, locs);
- // Replace the placeholder values with the new arguments. We assume there is
- // only one block for now.
+ // Replace the placeholder values with the new arguments.
size_t unrolledInputIdx = 0;
- for (auto [count, op] : enumerate(entryBlock.getOperations())) {
+ newFuncOp.walk([&](Operation *op) {
// We first look for operands that are placeholders for initially legal
// arguments.
- Operation &curOp = op;
- for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
+ for (auto [operandIdx, operandVal] : llvm::enumerate(op->getOperands())) {
Operation *operandOp = operandVal.getDefiningOp();
if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
size_t idx = operandIdx;
- rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] {
- curOp.setOperand(idx, newFuncOp.getArgument(it->second));
+ rewriter.modifyOpInPlace(op, [&] {
+ op->setOperand(idx, newFuncOp.getArgument(it->second));
});
}
}
+
// Since all newly created operations are in the beginning, reaching the
// end of them means that any later `vector.insert_strided_slice` should
// not be touched.
- if (count >= newOpCount)
- continue;
+ if (op->getBlock() == &entryBlock &&
+ static_cast<size_t>(std::distance(entryBlock.begin(),
+ op->getIterator())) >= newOpCount)
+ return;
+
if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
- rewriter.modifyOpInPlace(&curOp, [&] {
- curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
+ rewriter.modifyOpInPlace(op, [&] {
+ op->setOperand(0, newFuncOp.getArgument(unrolledInputNo));
});
++unrolledInputIdx;
}
- }
+ });
// Erase the original funcOp. The `tmpOps` do not need to be erased since
// they have no uses and will be handled by dead-code elimination.
>From c720c6b9c8ee99c206f0deadd2295c674990b8a3 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Sun, 8 Jun 2025 21:52:27 -0400
Subject: [PATCH 2/5] Fix vec slice replacement?
---
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 84796fdeda03a..04070d50785ba 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1038,7 +1038,7 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
// Since all newly created operations are in the beginning, reaching the
// end of them means that any later `vector.insert_strided_slice` should
// not be touched.
- if (op->getBlock() == &entryBlock &&
+ if (op->getBlock() != &entryBlock ||
static_cast<size_t>(std::distance(entryBlock.begin(),
op->getIterator())) >= newOpCount)
return;
>From 190b9aa1114d2e2cfbec601412dd2fbd1c3c2488 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Mon, 9 Jun 2025 02:11:22 -0400
Subject: [PATCH 3/5] Add tests and update comment
---
.../SPIRV/Transforms/SPIRVConversion.cpp | 8 +-
.../func-signature-vector-unroll.mlir | 73 +++++++++++++++++++
2 files changed, 78 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 04070d50785ba..501a7286a6dbc 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1035,9 +1035,11 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
}
}
- // Since all newly created operations are in the beginning, reaching the
- // end of them means that any later `vector.insert_strided_slice` should
- // not be touched.
+ // Only consider `vector.insert_strided_slice` ops that were newly created
+ // at the beginning of the entry block. Once we encounter operations
+ // outside the entry block or past the `newOpCount`-th operation in the
+ // entry block, we skip and leave existing `vector.insert_strided_slice`
+ // ops as is.
if (op->getBlock() != &entryBlock ||
static_cast<size_t>(std::distance(entryBlock.begin(),
op->getIterator())) >= newOpCount)
diff --git a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
index c018ccb924983..211d6c90243bd 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
@@ -189,3 +189,76 @@ func.func @unsupported_scalable(%arg0 : vector<[8]xi32>) -> (vector<[8]xi32>) {
return %arg0 : vector<[8]xi32>
}
+// -----
+
+// Check that already legal function parameters are properly preserved across multiple blocks.
+
+// CHECK-LABEL: func.func @legal_params_multiple_blocks_simple
+// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32) -> i32
+func.func @legal_params_multiple_blocks_simple(%arg0: i32, %arg1: i32) -> i32 {
+ // CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
+ // CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[ARG1]] : i32
+ // CHECK: return %[[ADD1]] : i32
+ cf.br ^bb1(%arg0 : i32)
+^bb1(%acc0: i32):
+ %acc1_val = arith.addi %acc0, %arg1 : i32
+ cf.br ^bb2(%acc1_val : i32)
+^bb2(%acc1: i32):
+ %acc2_val = arith.addi %acc1, %arg1 : i32
+ cf.br ^bb3(%acc2_val : i32)
+^bb3(%acc_final: i32):
+ return %acc_final : i32
+}
+
+// -----
+
+// Check that legal parameters and existing `vector.insert_strided_slice`s are properly preserved across multiple blocks.
+
+// CHECK-LABEL: func.func @legal_params_with_vec_insert_multiple_blocks
+// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32, %[[ARG2:.+]]: vector<4xi32>) -> vector<4xi32>
+func.func @legal_params_with_vec_insert_multiple_blocks(%arg0: i32, %arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
+ // CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
+ // CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[ARG1]] : i32
+ // CHECK: %[[VEC1D:.*]] = vector.broadcast %[[ADD1]] : i32 to vector<1xi32>
+ // CHECK: %[[VEC0:.*]] = vector.insert_strided_slice %[[VEC1D]], %[[ARG2]] {offsets = [1], strides = [1]} : vector<1xi32> into vector<4xi32>
+ // CHECK: %[[VEC1:.*]] = vector.insert_strided_slice %[[VEC1D]], %[[VEC0]] {offsets = [2], strides = [1]} : vector<1xi32> into vector<4xi32>
+ // CHECK: %[[RESULT:.*]] = vector.insert_strided_slice %[[VEC1D]], %[[VEC1]] {offsets = [3], strides = [1]} : vector<1xi32> into vector<4xi32>
+ // CHECK: return %[[RESULT]] : vector<4xi32>
+ cf.br ^bb1(%arg0 : i32)
+^bb1(%acc0: i32):
+ %acc1_val = arith.addi %acc0, %arg1 : i32
+ cf.br ^bb2(%acc1_val : i32)
+^bb2(%acc1: i32):
+ %acc2_val = arith.addi %acc1, %arg1 : i32
+ cf.br ^bb3(%acc2_val : i32)
+^bb3(%acc_final: i32):
+ %scalar_vec = vector.broadcast %acc_final : i32 to vector<1xi32>
+ %vec0 = vector.insert_strided_slice %scalar_vec, %arg2 {offsets = [1], strides = [1]} : vector<1xi32> into vector<4xi32>
+ %vec1 = vector.insert_strided_slice %scalar_vec, %vec0 {offsets = [2], strides = [1]} : vector<1xi32> into vector<4xi32>
+ %result = vector.insert_strided_slice %scalar_vec, %vec1 {offsets = [3], strides = [1]} : vector<1xi32> into vector<4xi32>
+ return %result : vector<4xi32>
+}
+
+// -----
+
+// Check that already legal function parameters are preserved across a loop (which contains multiple blocks).
+
+// CHECK-LABEL: @legal_params_for_loop
+// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32)
+func.func @legal_params_for_loop(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
+ // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+ // CHECK: %[[CST1:.*]] = arith.constant 1 : index
+ // CHECK: %[[UB:.*]] = arith.index_cast %[[ARG2]] : i32 to index
+ // CHECK: %[[RESULT:.*]] = scf.for %[[STEP:.*]] = %[[CST0]] to %[[UB]] step %[[CST1]] iter_args(%[[ACC:.*]] = %[[ARG0]]) -> (i32) {
+ // CHECK: %[[ADD:.*]] = arith.addi %[[ACC]], %[[ARG1]] : i32
+ // CHECK: scf.yield %[[ADD]] : i32
+ // CHECK: return %[[RESULT]] : i32
+ %zero = arith.constant 0 : index
+ %one = arith.constant 1 : index
+ %ub = arith.index_cast %arg2 : i32 to index
+ %result = scf.for %i = %zero to %ub step %one iter_args(%acc = %arg0) -> (i32) {
+ %new_acc = arith.addi %acc, %arg1 : i32
+ scf.yield %new_acc : i32
+ }
+ return %result : i32
+}
>From e496a3e3374b0d6fb4dc899b204c73e6dc076f94 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Tue, 10 Jun 2025 02:16:21 -0400
Subject: [PATCH 4/5] Improve implementation to not use walk
---
.../SPIRV/Transforms/SPIRVConversion.cpp | 48 ++++++++-----------
1 file changed, 21 insertions(+), 27 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 501a7286a6dbc..ee4271263bb26 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1020,39 +1020,33 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
entryBlock.addArguments(convertedTypes, locs);
- // Replace the placeholder values with the new arguments.
- size_t unrolledInputIdx = 0;
- newFuncOp.walk([&](Operation *op) {
- // We first look for operands that are placeholders for initially legal
- // arguments.
- for (auto [operandIdx, operandVal] : llvm::enumerate(op->getOperands())) {
- Operation *operandOp = operandVal.getDefiningOp();
- if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
- size_t idx = operandIdx;
- rewriter.modifyOpInPlace(op, [&] {
- op->setOperand(idx, newFuncOp.getArgument(it->second));
- });
- }
- }
-
- // Only consider `vector.insert_strided_slice` ops that were newly created
- // at the beginning of the entry block. Once we encounter operations
- // outside the entry block or past the `newOpCount`-th operation in the
- // entry block, we skip and leave existing `vector.insert_strided_slice`
- // ops as is.
- if (op->getBlock() != &entryBlock ||
- static_cast<size_t>(std::distance(entryBlock.begin(),
- op->getIterator())) >= newOpCount)
- return;
+ // Replace all uses of placeholders for initially legal arguments with their
+ // original function arguments (that were added to `newFuncOp`).
+ for (auto &[placeholderOp, argIdx] : tmpOps) {
+ if (!placeholderOp)
+ continue;
+ Value replacement = newFuncOp.getArgument(argIdx);
+ rewriter.replaceAllUsesWith(placeholderOp->getResult(0), replacement);
+ }
+ // Replace dummy operands of new `vector.insert_strided_slice` ops with
+ // their corresponding new function arguments.
+ size_t unrolledInputIdx = 0;
+ for (auto [count, op] : enumerate(entryBlock.getOperations())) {
+ Operation &curOp = op;
+ // Since all newly created operations are in the beginning, reaching the
+ // end of them means that any later `vector.insert_strided_slice` should
+ // not be touched.
+ if (count >= newOpCount)
+ continue;
if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
- rewriter.modifyOpInPlace(op, [&] {
- op->setOperand(0, newFuncOp.getArgument(unrolledInputNo));
+ rewriter.modifyOpInPlace(&curOp, [&] {
+ curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
});
++unrolledInputIdx;
}
- });
+ }
// Erase the original funcOp. The `tmpOps` do not need to be erased since
// they have no uses and will be handled by dead-code elimination.
>From 4f2616da5910713004374c5283147d65f0cabb70 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Fri, 13 Jun 2025 00:10:25 -0400
Subject: [PATCH 5/5] Add comment on why iterating only the entry block for
vector insert replacement is OK
---
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index ee4271263bb26..f5a58c58e05df 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1030,7 +1030,9 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
}
// Replace dummy operands of new `vector.insert_strided_slice` ops with
- // their corresponding new function arguments.
+ // their corresponding new function arguments. The new
+ // `vector.insert_strided_slice` ops are inserted only into the entry block,
+ // so iterating over that block is sufficient.
size_t unrolledInputIdx = 0;
for (auto [count, op] : enumerate(entryBlock.getOperations())) {
Operation &curOp = op;
More information about the Mlir-commits
mailing list