[Mlir-commits] [mlir] 9e62298 - [mlir][spirv] Fix FuncOpVectorUnroll to process placeholder values in all blocks (#142339)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 13 08:06:35 PDT 2025


Author: Darren Wihandi
Date: 2025-06-13T11:06:31-04:00
New Revision: 9e622986526a35f3f8bc60a7fc756b5c7bf825c0

URL: https://github.com/llvm/llvm-project/commit/9e622986526a35f3f8bc60a7fc756b5c7bf825c0
DIFF: https://github.com/llvm/llvm-project/commit/9e622986526a35f3f8bc60a7fc756b5c7bf825c0.diff

LOG: [mlir][spirv] Fix FuncOpVectorUnroll to process placeholder values in all blocks (#142339)

`FuncOpVectorUnroll` contains logic that replaces function arguments with
placeholder values. These replacements also involve changing all
instructions in the function that use the arguments to use these
placeholders instead. The placeholder values are later replaced back with
function arguments (either the new ones or, if already legal, the original
ones).

However, the current implementation performs the second replacement
(i.e. replacing the placeholder values with the new/legal arguments) only
for the first block of instructions, not for all of the blocks. This may
leave some instructions using these placeholder values (which, for
already legal arguments, are just zeroattr values that will get DCE'd)
instead of the arguments, which is incorrect.

Closes #132158.

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 62a24646d0662..f5a58c58e05df 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1020,22 +1020,22 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
     SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
     entryBlock.addArguments(convertedTypes, locs);
 
-    // Replace the placeholder values with the new arguments. We assume there is
-    // only one block for now.
+    // Replace all uses of placeholders for initially legal arguments with their
+    // original function arguments (that were added to `newFuncOp`).
+    for (auto &[placeholderOp, argIdx] : tmpOps) {
+      if (!placeholderOp)
+        continue;
+      Value replacement = newFuncOp.getArgument(argIdx);
+      rewriter.replaceAllUsesWith(placeholderOp->getResult(0), replacement);
+    }
+
+    // Replace dummy operands of new `vector.insert_strided_slice` ops with
+    // their corresponding new function arguments. The new
+    // `vector.insert_strided_slice` ops are inserted only into the entry block,
+    // so iterating over that block is sufficient.
     size_t unrolledInputIdx = 0;
     for (auto [count, op] : enumerate(entryBlock.getOperations())) {
-      // We first look for operands that are placeholders for initially legal
-      // arguments.
       Operation &curOp = op;
-      for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
-        Operation *operandOp = operandVal.getDefiningOp();
-        if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
-          size_t idx = operandIdx;
-          rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] {
-            curOp.setOperand(idx, newFuncOp.getArgument(it->second));
-          });
-        }
-      }
       // Since all newly created operations are in the beginning, reaching the
       // end of them means that any later `vector.insert_strided_slice` should
       // not be touched.

diff  --git a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
index c018ccb924983..211d6c90243bd 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
@@ -189,3 +189,76 @@ func.func @unsupported_scalable(%arg0 : vector<[8]xi32>) -> (vector<[8]xi32>) {
   return %arg0 : vector<[8]xi32>
 }
 
+// -----
+
+// Check that already legal function parameters are properly preserved across multiple blocks.
+
+// CHECK-LABEL: func.func @legal_params_multiple_blocks_simple
+// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32) -> i32
+func.func @legal_params_multiple_blocks_simple(%arg0: i32, %arg1: i32) -> i32 {
+  // CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
+  // CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[ARG1]] : i32
+  // CHECK: return %[[ADD1]] : i32
+  cf.br ^bb1(%arg0 : i32)
+^bb1(%acc0: i32):
+  %acc1_val = arith.addi %acc0, %arg1 : i32
+  cf.br ^bb2(%acc1_val : i32)
+^bb2(%acc1: i32):
+  %acc2_val = arith.addi %acc1, %arg1 : i32
+  cf.br ^bb3(%acc2_val : i32)
+^bb3(%acc_final: i32):
+  return %acc_final : i32
+}
+
+// -----
+
+// Check that legal parameters and existing `vector.insert_strided_slice`s are properly preserved across multiple blocks.
+
+// CHECK-LABEL: func.func @legal_params_with_vec_insert_multiple_blocks
+// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32, %[[ARG2:.+]]: vector<4xi32>) -> vector<4xi32>
+func.func @legal_params_with_vec_insert_multiple_blocks(%arg0: i32, %arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
+  // CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
+  // CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[ARG1]] : i32
+  // CHECK: %[[VEC1D:.*]] = vector.broadcast %[[ADD1]] : i32 to vector<1xi32>
+  // CHECK: %[[VEC0:.*]] = vector.insert_strided_slice %[[VEC1D]], %[[ARG2]] {offsets = [1], strides = [1]} : vector<1xi32> into vector<4xi32>
+  // CHECK: %[[VEC1:.*]] = vector.insert_strided_slice %[[VEC1D]], %[[VEC0]] {offsets = [2], strides = [1]} : vector<1xi32> into vector<4xi32>
+  // CHECK: %[[RESULT:.*]] = vector.insert_strided_slice %[[VEC1D]], %[[VEC1]] {offsets = [3], strides = [1]} : vector<1xi32> into vector<4xi32>
+  // CHECK: return %[[RESULT]] : vector<4xi32>
+  cf.br ^bb1(%arg0 : i32)
+^bb1(%acc0: i32):
+  %acc1_val = arith.addi %acc0, %arg1 : i32
+  cf.br ^bb2(%acc1_val : i32)
+^bb2(%acc1: i32):
+  %acc2_val = arith.addi %acc1, %arg1 : i32
+  cf.br ^bb3(%acc2_val : i32)
+^bb3(%acc_final: i32):
+  %scalar_vec = vector.broadcast %acc_final : i32 to vector<1xi32>
+  %vec0 = vector.insert_strided_slice %scalar_vec, %arg2 {offsets = [1], strides = [1]} : vector<1xi32> into vector<4xi32>
+  %vec1 = vector.insert_strided_slice %scalar_vec, %vec0 {offsets = [2], strides = [1]} : vector<1xi32> into vector<4xi32>
+  %result = vector.insert_strided_slice %scalar_vec, %vec1 {offsets = [3], strides = [1]} : vector<1xi32> into vector<4xi32>
+  return %result : vector<4xi32>
+}
+
+// -----
+
+// Check that already legal function parameters are preserved across a loop (which contains multiple blocks).
+
+// CHECK-LABEL: @legal_params_for_loop
+// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32)
+func.func @legal_params_for_loop(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
+  // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+  // CHECK: %[[CST1:.*]] = arith.constant 1 : index
+  // CHECK: %[[UB:.*]] = arith.index_cast %[[ARG2]] : i32 to index
+  // CHECK: %[[RESULT:.*]] = scf.for %[[STEP:.*]] = %[[CST0]] to %[[UB]] step %[[CST1]] iter_args(%[[ACC:.*]] = %[[ARG0]]) -> (i32) {
+  // CHECK:   %[[ADD:.*]] = arith.addi %[[ACC]], %[[ARG1]] : i32
+  // CHECK:   scf.yield %[[ADD]] : i32
+  // CHECK: return %[[RESULT]] : i32
+  %zero = arith.constant 0 : index
+  %one = arith.constant 1 : index
+  %ub = arith.index_cast %arg2 : i32 to index
+  %result = scf.for %i = %zero to %ub step %one iter_args(%acc = %arg0) -> (i32) {
+    %new_acc = arith.addi %acc, %arg1 : i32
+    scf.yield %new_acc : i32
+  }
+  return %result : i32
+}


        


More information about the Mlir-commits mailing list