[Mlir-commits] [mlir] [mlir][spirv] Fix FuncOpVectorUnroll to process placeholder values in all blocks (PR #142339)

Darren Wihandi llvmlistbot at llvm.org
Mon Jun 2 00:09:09 PDT 2025


https://github.com/fairywreath created https://github.com/llvm/llvm-project/pull/142339

`FuncOpVectorUnroll` contains logic that replaces function arguments with placeholder values. This replacement also changes every instruction in the function that uses the arguments so that it uses the placeholders instead. The placeholder values are later replaced again with function arguments (either new ones, or the original ones if they were already legal).

The current implementation, however, performs this second replacement (placeholder values back to new/legal arguments) only for the instructions in the first block, not for those in all blocks. Instructions in later blocks may therefore keep using the placeholder values (which for already legal arguments are just zeroattr values that will get DCE'd) instead of the arguments, which is incorrect.
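
To make the failure mode concrete, here is a minimal sketch in plain C++ rather than MLIR. The types `Value`, `Op`, `Block`, and `Function` and the helpers `restoreArguments`, `countPlaceholderUses`, and `makeTwoBlockExample` are hypothetical stand-ins invented for illustration; they are not part of `SPIRVConversion.cpp` or of any MLIR API. The sketch only models the second replacement step and shows how restricting it to the entry block leaves placeholder uses behind.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

// Toy stand-ins for illustration only; these are not MLIR classes.
struct Value {
  enum Kind { Argument, Placeholder } kind;
  std::size_t index; // argument number, or placeholder id
};
struct Op {
  std::vector<Value> operands;
};
struct Block {
  std::vector<Op> ops;
};
struct Function {
  std::vector<Block> blocks; // blocks[0] plays the role of the entry block
};

// Rewrites placeholder operands back to the arguments they stand for.
// `placeholderToArg` maps placeholder id -> argument number.
// With `entryBlockOnly` set, only blocks[0] is visited, which models the
// behavior described above as incorrect.
void restoreArguments(Function &fn,
                      const std::map<std::size_t, std::size_t> &placeholderToArg,
                      bool entryBlockOnly) {
  std::size_t numBlocks = fn.blocks.size();
  if (entryBlockOnly && numBlocks > 1)
    numBlocks = 1;
  for (std::size_t b = 0; b < numBlocks; ++b) {
    for (Op &op : fn.blocks[b].ops) {
      for (Value &operand : op.operands) {
        if (operand.kind != Value::Placeholder)
          continue;
        auto it = placeholderToArg.find(operand.index);
        if (it != placeholderToArg.end())
          operand = Value{Value::Argument, it->second};
      }
    }
  }
}

// Counts operands that still refer to a placeholder.
std::size_t countPlaceholderUses(const Function &fn) {
  std::size_t count = 0;
  for (const Block &block : fn.blocks)
    for (const Op &op : block.ops)
      for (const Value &operand : op.operands)
        if (operand.kind == Value::Placeholder)
          ++count;
  return count;
}

// Builds a function with two blocks, each holding one op that uses
// placeholder 0.
Function makeTwoBlockExample() {
  Function fn;
  fn.blocks.resize(2);
  fn.blocks[0].ops.push_back(Op{{Value{Value::Placeholder, 0}}});
  fn.blocks[1].ops.push_back(Op{{Value{Value::Placeholder, 0}}});
  return fn;
}
```

In this toy model, calling `restoreArguments` with `entryBlockOnly = true` on the two-block example leaves one placeholder use in the second block, while visiting all blocks leaves none. The actual patch achieves the all-blocks traversal with `newFuncOp.walk`, as shown in the diff.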

Closes #132158.

TODO: add test

From 393ef3602a60f987e8d14dabeae2b32071504813 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Mon, 2 Jun 2025 02:43:54 -0400
Subject: [PATCH] [mlir][spirv] Fix FuncOpVectorUnroll to process all blocks

---
 .../SPIRV/Transforms/SPIRVConversion.cpp      | 26 ++++++++++---------
 1 file changed, 14 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 62a24646d0662..84796fdeda03a 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1020,35 +1020,37 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
     SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
     entryBlock.addArguments(convertedTypes, locs);
 
-    // Replace the placeholder values with the new arguments. We assume there is
-    // only one block for now.
+    // Replace the placeholder values with the new arguments.
     size_t unrolledInputIdx = 0;
-    for (auto [count, op] : enumerate(entryBlock.getOperations())) {
+    newFuncOp.walk([&](Operation *op) {
       // We first look for operands that are placeholders for initially legal
       // arguments.
-      Operation &curOp = op;
-      for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
+      for (auto [operandIdx, operandVal] : llvm::enumerate(op->getOperands())) {
         Operation *operandOp = operandVal.getDefiningOp();
         if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
           size_t idx = operandIdx;
-          rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] {
-            curOp.setOperand(idx, newFuncOp.getArgument(it->second));
+          rewriter.modifyOpInPlace(op, [&] {
+            op->setOperand(idx, newFuncOp.getArgument(it->second));
           });
         }
       }
+
       // Since all newly created operations are in the beginning, reaching the
       // end of them means that any later `vector.insert_strided_slice` should
       // not be touched.
-      if (count >= newOpCount)
-        continue;
+      if (op->getBlock() == &entryBlock &&
+          static_cast<size_t>(std::distance(entryBlock.begin(),
+                                            op->getIterator())) >= newOpCount)
+        return;
+
       if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
         size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
-        rewriter.modifyOpInPlace(&curOp, [&] {
-          curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
+        rewriter.modifyOpInPlace(op, [&] {
+          op->setOperand(0, newFuncOp.getArgument(unrolledInputNo));
         });
         ++unrolledInputIdx;
       }
-    }
+    });
 
     // Erase the original funcOp. The `tmpOps` do not need to be erased since
     // they have no uses and will be handled by dead-code elimination.


