[Mlir-commits] [mlir] f7a5264 - [mlir][vector] Add support for yielding loop bounds in `scf.for` distribution. (#163443)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 17 09:07:21 PDT 2025


Author: Charitha Saumya
Date: 2025-10-17T09:07:17-07:00
New Revision: f7a5264890fe050124cd576410695a7c90c4d8d8

URL: https://github.com/llvm/llvm-project/commit/f7a5264890fe050124cd576410695a7c90c4d8d8
DIFF: https://github.com/llvm/llvm-project/commit/f7a5264890fe050124cd576410695a7c90c4d8d8.diff

LOG: [mlir][vector] Add support for yielding loop bounds in `scf.for` distribution.  (#163443)

In some cases, the loop bounds (lower bound, upper bound and step) of an
`scf.for` can be values local to the parent warp op of the `scf.for`. The
current logic does not yield the loop bounds from the new warp op generated
during lowering, causing the sunk `scf.for` to have non-dominating uses.

This PR adds logic to yield the loop bounds by default (treating them like
the other operands of `scf.for`), which fixes this bug.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 12e6475fa66e3..7c019e7d25bf2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -2032,11 +2032,19 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     }
 
     // Newly created `WarpOp` will yield values in following order:
-    // 1. All init args of the `ForOp`.
-    // 2. All escaping values.
-    // 3. All non-`ForOp` yielded values.
+    // 1. Loop bounds.
+    // 2. All init args of the `ForOp`.
+    // 3. All escaping values.
+    // 4. All non-`ForOp` yielded values.
     SmallVector<Value> newWarpOpYieldValues;
     SmallVector<Type> newWarpOpDistTypes;
+    newWarpOpYieldValues.insert(
+        newWarpOpYieldValues.end(),
+        {forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()});
+    newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
+                              {forOp.getLowerBound().getType(),
+                               forOp.getUpperBound().getType(),
+                               forOp.getStep().getType()});
     for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
       newWarpOpYieldValues.push_back(initArg);
       // Compute the distributed type for this init arg.
@@ -2072,20 +2080,24 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
 
     // Next, we create a new `ForOp` with the init args yielded by the new
     // `WarpOp`.
+    const unsigned initArgsStartIdx = 3; // After loop bounds.
     const unsigned escapingValuesStartIdx =
+        initArgsStartIdx +
         forOp.getInitArgs().size(); // `ForOp` init args are positioned before
                                     // escaping values in the new `WarpOp`.
     SmallVector<Value> newForOpOperands;
-    for (size_t i = 0; i < escapingValuesStartIdx; ++i)
+    for (size_t i = initArgsStartIdx; i < escapingValuesStartIdx; ++i)
       newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));
 
     // Create a new `ForOp` outside the new `WarpOp` region.
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(newWarpOp);
     auto newForOp = scf::ForOp::create(
-        rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
-        forOp.getStep(), newForOpOperands, /*bodyBuilder=*/nullptr,
-        forOp.getUnsignedCmp());
+        rewriter, forOp.getLoc(),
+        /**LowerBound=**/ newWarpOp.getResult(newIndices[0]),
+        /**UpperBound=**/ newWarpOp.getResult(newIndices[1]),
+        /**Step=**/ newWarpOp.getResult(newIndices[2]), newForOpOperands,
+        /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
     // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
     // newly created `ForOp`. This `WarpOp` will contain all ops that were
     // contained within the original `ForOp` body.

diff  --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 401cdd29b281c..0cf6dd151e16c 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -473,6 +473,41 @@ func.func @warp_scf_for_use_from_above(%arg0: index) {
   return
 }
 
+// -----
+// CHECK-PROP-LABEL:  func.func @warp_scf_for_local_loop_bounds
+// CHECK-PROP:          (%{{.*}}: index, %[[ARG1:[a-zA-Z0-9]+]]: index) {
+// CHECK-PROP:          %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%[[ARG1]] : index) -> (vector<4xf32>) {
+// CHECK-PROP:          ^bb0(%{{.*}}: index):
+// CHECK-PROP:            %[[T2:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP:            gpu.yield %[[T2]] : vector<128xf32>
+// CHECK-PROP:          }
+// CHECK-PROP:          %[[FOR:.*]] = scf.for %{{.*}} to %[[ARG1]] step %{{.*}} iter_args(%{{.*}}) -> (vector<4xf32>) {
+// CHECK-PROP:            %[[W2:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32]
+// CHECK-PROP-SAME:          args(%{{.*}} : vector<4xf32>) -> (vector<4xf32>) {
+// CHECK-PROP:            ^bb0(%{{.*}}: vector<128xf32>):
+// CHECK-PROP:              gpu.yield %{{.*}} : vector<128xf32>
+// CHECK-PROP:            }
+// CHECK-PROP:            scf.yield %[[W2]] : vector<4xf32>
+// CHECK-PROP:          }
+// CHECK-PROP:          "some_use"(%[[FOR]]) : (vector<4xf32>) -> ()
+// CHECK-PROP:          return
+func.func @warp_scf_for_local_loop_bounds(%arg0: index, %bound: index) {
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %0 = gpu.warp_execute_on_lane_0(%arg0)[32]
+    args(%bound : index) -> (vector<4xf32>) {
+    ^bb0(%arg1: index):
+    %ini = "some_def"() : () -> (vector<128xf32>)
+    %3 = scf.for %arg3 = %c0 to %arg1 step %c1 iter_args(%arg4 = %ini) -> (vector<128xf32>) {
+      %acc = "some_def"(%arg4) : (vector<128xf32>) -> (vector<128xf32>)
+      scf.yield %acc : vector<128xf32>
+    }
+    gpu.yield %3 : vector<128xf32>
+  }
+  "some_use"(%0) : (vector<4xf32>) -> ()
+  return
+}
+
 // -----
 
 // CHECK-PROP-LABEL:   func @warp_scf_for_swap(


        


More information about the Mlir-commits mailing list