[Mlir-commits] [mlir] [mlir][vector] Add support for yielding loop bounds in `scf.for` distribution. (PR #163443)

Charitha Saumya llvmlistbot at llvm.org
Tue Oct 14 13:06:19 PDT 2025


https://github.com/charithaintc created https://github.com/llvm/llvm-project/pull/163443

In some cases, the loop bounds (lower bound, upper bound and step) of an `scf.for` can be defined locally inside the parent warp op of the `scf.for`. The current logic does not yield these loop bounds from the new warp op generated during lowering, causing the sunk `scf.for` to have non-dominating uses.

This PR adds logic to yield the loop bounds by default (treating them like the other operands of `scf.for`), which fixes this bug.

>From d9ad36dfc4c47e64e2ca0170f13569c3e904dd79 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 14 Oct 2025 20:01:35 +0000
Subject: [PATCH] add fix

---
 .../Vector/Transforms/VectorDistribute.cpp    | 25 +++++++++----
 .../Vector/vector-warp-distribute.mlir        | 35 +++++++++++++++++++
 2 files changed, 53 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e95338f7d18be..2ee65dc0f902a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -2038,11 +2038,19 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     }
 
     // Newly created `WarpOp` will yield values in following order:
-    // 1. All init args of the `ForOp`.
-    // 2. All escaping values.
-    // 3. All non-`ForOp` yielded values.
+    // 1. Loop bounds.
+    // 2. All init args of the `ForOp`.
+    // 3. All escaping values.
+    // 4. All non-`ForOp` yielded values.
     SmallVector<Value> newWarpOpYieldValues;
     SmallVector<Type> newWarpOpDistTypes;
+    newWarpOpYieldValues.insert(
+        newWarpOpYieldValues.end(),
+        {forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()});
+    newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
+                              {forOp.getLowerBound().getType(),
+                               forOp.getUpperBound().getType(),
+                               forOp.getStep().getType()});
     for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
       newWarpOpYieldValues.push_back(initArg);
       // Compute the distributed type for this init arg.
@@ -2081,20 +2089,23 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
 
     // Next, we create a new `ForOp` with the init args yielded by the new
     // `WarpOp`.
+    const unsigned initArgsStartIdx = 3; // After loop bounds.
     const unsigned escapingValuesStartIdx =
+        initArgsStartIdx +
         forOp.getInitArgs().size(); // `ForOp` init args are positioned before
                                     // escaping values in the new `WarpOp`.
     SmallVector<Value> newForOpOperands;
-    for (size_t i = 0; i < escapingValuesStartIdx; ++i)
+    for (size_t i = initArgsStartIdx; i < escapingValuesStartIdx; ++i)
       newForOpOperands.push_back(newWarpOp.getResult(i));
 
     // Create a new `ForOp` outside the new `WarpOp` region.
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(newWarpOp);
     auto newForOp = scf::ForOp::create(
-        rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
-        forOp.getStep(), newForOpOperands, /*bodyBuilder=*/nullptr,
-        forOp.getUnsignedCmp());
+        rewriter, forOp.getLoc(), /**LowerBound=**/ newWarpOp.getResult(0),
+        /**UpperBound=**/ newWarpOp.getResult(1),
+        /**Step=**/ newWarpOp.getResult(2), newForOpOperands,
+        /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
     // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
     // newly created `ForOp`. This `WarpOp` will contain all ops that were
     // contained within the original `ForOp` body.
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index bb7639204022f..ab87684dbb01a 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -473,6 +473,41 @@ func.func @warp_scf_for_use_from_above(%arg0: index) {
   return
 }
 
+// -----
+// CHECK-PROP-LABEL:  func.func @warp_scf_for_local_loop_bounds
+// CHECK-PROP:          (%{{.*}}: index, %[[ARG1:[a-zA-Z0-9]+]]: index) {
+// CHECK-PROP:          %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%[[ARG1]] : index) -> (vector<4xf32>) {
+// CHECK-PROP:          ^bb0(%{{.*}}: index):
+// CHECK-PROP:            %[[T2:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP:            gpu.yield %[[T2]] : vector<128xf32>
+// CHECK-PROP:          }
+// CHECK-PROP:          %[[FOR:.*]] = scf.for %{{.*}} to %[[ARG1]] step %{{.*}} iter_args(%{{.*}}) -> (vector<4xf32>) {
+// CHECK-PROP:            %[[W2:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32]
+// CHECK-PROP-SAME:          args(%{{.*}} : vector<4xf32>) -> (vector<4xf32>) {
+// CHECK-PROP:            ^bb0(%{{.*}}: vector<128xf32>):
+// CHECK-PROP:              gpu.yield %{{.*}} : vector<128xf32>
+// CHECK-PROP:            }
+// CHECK-PROP:            scf.yield %[[W2]] : vector<4xf32>
+// CHECK-PROP:          }
+// CHECK-PROP:          "some_use"(%[[FOR]]) : (vector<4xf32>) -> ()
+// CHECK-PROP:          return
+func.func @warp_scf_for_local_loop_bounds(%arg0: index, %bound: index) {
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %0 = gpu.warp_execute_on_lane_0(%arg0)[32]
+    args(%bound : index) -> (vector<4xf32>) {
+    ^bb0(%arg1: index):
+    %ini = "some_def"() : () -> (vector<128xf32>)
+    %3 = scf.for %arg3 = %c0 to %arg1 step %c1 iter_args(%arg4 = %ini) -> (vector<128xf32>) {
+      %acc = "some_def"(%arg4) : (vector<128xf32>) -> (vector<128xf32>)
+      scf.yield %acc : vector<128xf32>
+    }
+    gpu.yield %3 : vector<128xf32>
+  }
+  "some_use"(%0) : (vector<4xf32>) -> ()
+  return
+}
+
 // -----
 
 // CHECK-PROP-LABEL:   func @warp_scf_for_swap(



More information about the Mlir-commits mailing list