[Mlir-commits] [mlir] [mlir][spirv] Fix FuncOpVectorUnroll to process placeholder values in all blocks (PR #142339)

Darren Wihandi llvmlistbot at llvm.org
Thu Jun 12 21:11:14 PDT 2025


https://github.com/fairywreath updated https://github.com/llvm/llvm-project/pull/142339

>From 393ef3602a60f987e8d14dabeae2b32071504813 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Mon, 2 Jun 2025 02:43:54 -0400
Subject: [PATCH 1/5] [mlir][spirv] Fix FuncOpVectorUnroll to process all
 blocks

---
 .../SPIRV/Transforms/SPIRVConversion.cpp      | 26 ++++++++++---------
 1 file changed, 14 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 62a24646d0662..84796fdeda03a 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1020,35 +1020,37 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
     SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
     entryBlock.addArguments(convertedTypes, locs);
 
-    // Replace the placeholder values with the new arguments. We assume there is
-    // only one block for now.
+    // Replace the placeholder values with the new arguments.
     size_t unrolledInputIdx = 0;
-    for (auto [count, op] : enumerate(entryBlock.getOperations())) {
+    newFuncOp.walk([&](Operation *op) {
       // We first look for operands that are placeholders for initially legal
       // arguments.
-      Operation &curOp = op;
-      for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
+      for (auto [operandIdx, operandVal] : llvm::enumerate(op->getOperands())) {
         Operation *operandOp = operandVal.getDefiningOp();
         if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
           size_t idx = operandIdx;
-          rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] {
-            curOp.setOperand(idx, newFuncOp.getArgument(it->second));
+          rewriter.modifyOpInPlace(op, [&] {
+            op->setOperand(idx, newFuncOp.getArgument(it->second));
           });
         }
       }
+
       // Since all newly created operations are in the beginning, reaching the
       // end of them means that any later `vector.insert_strided_slice` should
       // not be touched.
-      if (count >= newOpCount)
-        continue;
+      if (op->getBlock() == &entryBlock &&
+          static_cast<size_t>(std::distance(entryBlock.begin(),
+                                            op->getIterator())) >= newOpCount)
+        return;
+
       if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
         size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
-        rewriter.modifyOpInPlace(&curOp, [&] {
-          curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
+        rewriter.modifyOpInPlace(op, [&] {
+          op->setOperand(0, newFuncOp.getArgument(unrolledInputNo));
         });
         ++unrolledInputIdx;
       }
-    }
+    });
 
     // Erase the original funcOp. The `tmpOps` do not need to be erased since
     // they have no uses and will be handled by dead-code elimination.

>From c720c6b9c8ee99c206f0deadd2295c674990b8a3 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Sun, 8 Jun 2025 21:52:27 -0400
Subject: [PATCH 2/5] Fix guard for vector.insert_strided_slice replacement

---
 mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 84796fdeda03a..04070d50785ba 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1038,7 +1038,7 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
       // Since all newly created operations are in the beginning, reaching the
       // end of them means that any later `vector.insert_strided_slice` should
       // not be touched.
-      if (op->getBlock() == &entryBlock &&
+      if (op->getBlock() != &entryBlock ||
           static_cast<size_t>(std::distance(entryBlock.begin(),
                                             op->getIterator())) >= newOpCount)
         return;

>From 190b9aa1114d2e2cfbec601412dd2fbd1c3c2488 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Mon, 9 Jun 2025 02:11:22 -0400
Subject: [PATCH 3/5] Add tests and update comment

---
 .../SPIRV/Transforms/SPIRVConversion.cpp      |  8 +-
 .../func-signature-vector-unroll.mlir         | 73 +++++++++++++++++++
 2 files changed, 78 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 04070d50785ba..501a7286a6dbc 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1035,9 +1035,11 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
         }
       }
 
-      // Since all newly created operations are in the beginning, reaching the
-      // end of them means that any later `vector.insert_strided_slice` should
-      // not be touched.
+      // Only consider `vector.insert_strided_slice` ops that were newly created
+      // at the beginning of the entry block. Once we encounter operations
+      // outside the entry block or past the `newOpCount`-th operation in the
+      // entry block, we skip and leave exisintg `vector.insert_strided_slice`
+      // ops as is.
       if (op->getBlock() != &entryBlock ||
           static_cast<size_t>(std::distance(entryBlock.begin(),
                                             op->getIterator())) >= newOpCount)
diff --git a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
index c018ccb924983..211d6c90243bd 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
@@ -189,3 +189,76 @@ func.func @unsupported_scalable(%arg0 : vector<[8]xi32>) -> (vector<[8]xi32>) {
   return %arg0 : vector<[8]xi32>
 }
 
+// -----
+
+// Check that already legal function parameters are properly preserved across multiple blocks.
+
+// CHECK-LABEL: func.func @legal_params_multiple_blocks_simple
+// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32) -> i32
+func.func @legal_params_multiple_blocks_simple(%arg0: i32, %arg1: i32) -> i32 {
+  // CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
+  // CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[ARG1]] : i32
+  // CHECK: return %[[ADD1]] : i32
+  cf.br ^bb1(%arg0 : i32)
+^bb1(%acc0: i32):
+  %acc1_val = arith.addi %acc0, %arg1 : i32
+  cf.br ^bb2(%acc1_val : i32)
+^bb2(%acc1: i32):
+  %acc2_val = arith.addi %acc1, %arg1 : i32
+  cf.br ^bb3(%acc2_val : i32)
+^bb3(%acc_final: i32):
+  return %acc_final : i32
+}
+
+// -----
+
+// Check that legal parameters and existing `vector.insert_strided_slice`s are properly preserved across multiple blocks.
+
+// CHECK-LABEL: func.func @legal_params_with_vec_insert_multiple_blocks
+// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32, %[[ARG2:.+]]: vector<4xi32>) -> vector<4xi32>
+func.func @legal_params_with_vec_insert_multiple_blocks(%arg0: i32, %arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
+  // CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
+  // CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[ARG1]] : i32
+  // CHECK: %[[VEC1D:.*]] = vector.broadcast %[[ADD1]] : i32 to vector<1xi32>
+  // CHECK: %[[VEC0:.*]] = vector.insert_strided_slice %[[VEC1D]], %[[ARG2]] {offsets = [1], strides = [1]} : vector<1xi32> into vector<4xi32>
+  // CHECK: %[[VEC1:.*]] = vector.insert_strided_slice %[[VEC1D]], %[[VEC0]] {offsets = [2], strides = [1]} : vector<1xi32> into vector<4xi32>
+  // CHECK: %[[RESULT:.*]] = vector.insert_strided_slice %[[VEC1D]], %[[VEC1]] {offsets = [3], strides = [1]} : vector<1xi32> into vector<4xi32>
+  // CHECK: return %[[RESULT]] : vector<4xi32>
+  cf.br ^bb1(%arg0 : i32)
+^bb1(%acc0: i32):
+  %acc1_val = arith.addi %acc0, %arg1 : i32
+  cf.br ^bb2(%acc1_val : i32)
+^bb2(%acc1: i32):
+  %acc2_val = arith.addi %acc1, %arg1 : i32
+  cf.br ^bb3(%acc2_val : i32)
+^bb3(%acc_final: i32):
+  %scalar_vec = vector.broadcast %acc_final : i32 to vector<1xi32>
+  %vec0 = vector.insert_strided_slice %scalar_vec, %arg2 {offsets = [1], strides = [1]} : vector<1xi32> into vector<4xi32>
+  %vec1 = vector.insert_strided_slice %scalar_vec, %vec0 {offsets = [2], strides = [1]} : vector<1xi32> into vector<4xi32>
+  %result = vector.insert_strided_slice %scalar_vec, %vec1 {offsets = [3], strides = [1]} : vector<1xi32> into vector<4xi32>
+  return %result : vector<4xi32>
+}
+
+// -----
+
+// Check that already legal function parameters are preserved across a loop (which contains multiple blocks).
+
+// CHECK-LABEL: @legal_params_for_loop
+// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32)
+func.func @legal_params_for_loop(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
+  // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+  // CHECK: %[[CST1:.*]] = arith.constant 1 : index
+  // CHECK: %[[UB:.*]] = arith.index_cast %[[ARG2]] : i32 to index
+  // CHECK: %[[RESULT:.*]] = scf.for %[[STEP:.*]] = %[[CST0]] to %[[UB]] step %[[CST1]] iter_args(%[[ACC:.*]] = %[[ARG0]]) -> (i32) {
+  // CHECK:   %[[ADD:.*]] = arith.addi %[[ACC]], %[[ARG1]] : i32
+  // CHECK:   scf.yield %[[ADD]] : i32
+  // CHECK: return %[[RESULT]] : i32
+  %zero = arith.constant 0 : index
+  %one = arith.constant 1 : index
+  %ub = arith.index_cast %arg2 : i32 to index
+  %result = scf.for %i = %zero to %ub step %one iter_args(%acc = %arg0) -> (i32) {
+    %new_acc = arith.addi %acc, %arg1 : i32
+    scf.yield %new_acc : i32
+  }
+  return %result : i32
+}

>From e496a3e3374b0d6fb4dc899b204c73e6dc076f94 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Tue, 10 Jun 2025 02:16:21 -0400
Subject: [PATCH 4/5] Improve implementation to not use walk

---
 .../SPIRV/Transforms/SPIRVConversion.cpp      | 48 ++++++++-----------
 1 file changed, 21 insertions(+), 27 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 501a7286a6dbc..ee4271263bb26 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1020,39 +1020,33 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
     SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
     entryBlock.addArguments(convertedTypes, locs);
 
-    // Replace the placeholder values with the new arguments.
-    size_t unrolledInputIdx = 0;
-    newFuncOp.walk([&](Operation *op) {
-      // We first look for operands that are placeholders for initially legal
-      // arguments.
-      for (auto [operandIdx, operandVal] : llvm::enumerate(op->getOperands())) {
-        Operation *operandOp = operandVal.getDefiningOp();
-        if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
-          size_t idx = operandIdx;
-          rewriter.modifyOpInPlace(op, [&] {
-            op->setOperand(idx, newFuncOp.getArgument(it->second));
-          });
-        }
-      }
-
-      // Only consider `vector.insert_strided_slice` ops that were newly created
-      // at the beginning of the entry block. Once we encounter operations
-      // outside the entry block or past the `newOpCount`-th operation in the
-      // entry block, we skip and leave exisintg `vector.insert_strided_slice`
-      // ops as is.
-      if (op->getBlock() != &entryBlock ||
-          static_cast<size_t>(std::distance(entryBlock.begin(),
-                                            op->getIterator())) >= newOpCount)
-        return;
+    // Replace all uses of placeholders for initially legal arguments with their
+    // original function arguments (that were added to `newFuncOp`).
+    for (auto &[placeholderOp, argIdx] : tmpOps) {
+      if (!placeholderOp)
+        continue;
+      Value replacement = newFuncOp.getArgument(argIdx);
+      rewriter.replaceAllUsesWith(placeholderOp->getResult(0), replacement);
+    }
 
+    // Replace dummy operands of new `vector.insert_strided_slice` ops with
+    // their corresponding new function arguments.
+    size_t unrolledInputIdx = 0;
+    for (auto [count, op] : enumerate(entryBlock.getOperations())) {
+      Operation &curOp = op;
+      // Since all newly created operations are in the beginning, reaching the
+      // end of them means that any later `vector.insert_strided_slice` should
+      // not be touched.
+      if (count >= newOpCount)
+        continue;
       if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
         size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
-        rewriter.modifyOpInPlace(op, [&] {
-          op->setOperand(0, newFuncOp.getArgument(unrolledInputNo));
+        rewriter.modifyOpInPlace(&curOp, [&] {
+          curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
         });
         ++unrolledInputIdx;
       }
-    });
+    }
 
     // Erase the original funcOp. The `tmpOps` do not need to be erased since
     // they have no uses and will be handled by dead-code elimination.

>From 4f2616da5910713004374c5283147d65f0cabb70 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Fri, 13 Jun 2025 00:10:25 -0400
Subject: [PATCH 5/5] Add comment on why iterating only the entry block for
 vector insert replacement is OK

---
 mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index ee4271263bb26..f5a58c58e05df 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1030,7 +1030,9 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
     }
 
     // Replace dummy operands of new `vector.insert_strided_slice` ops with
-    // their corresponding new function arguments.
+    // their corresponding new function arguments. The new
+    // `vector.insert_strided_slice` ops are inserted only into the entry block,
+    // so iterating over that block is sufficient.
     size_t unrolledInputIdx = 0;
     for (auto [count, op] : enumerate(entryBlock.getOperations())) {
       Operation &curOp = op;



More information about the Mlir-commits mailing list