[Mlir-commits] [mlir] Revert "[mlir][vector] Fix for WarpOpScfForOp failure when scf.for has results that are unused." (PR #144124)

Charitha Saumya llvmlistbot at llvm.org
Fri Jun 13 10:22:03 PDT 2025


https://github.com/charithaintc created https://github.com/llvm/llvm-project/pull/144124

Reverts llvm/llvm-project#141853

Reverting this bug fix because it does not handle all cases correctly.

From 2238fd9a756ae1a0b6aa2302e96cc217b08d6c3b Mon Sep 17 00:00:00 2001
From: Charitha Saumya <136391709+charithaintc at users.noreply.github.com>
Date: Fri, 13 Jun 2025 10:18:24 -0700
Subject: [PATCH] =?UTF-8?q?Revert=20"[mlir][vector]=20Fix=20for=20WarpOpSc?=
 =?UTF-8?q?fForOp=20failure=20when=20scf.for=20has=20result=E2=80=A6"?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This reverts commit 10dc8bc519130f491d70318bd8b47555307cdc3f.
---
 .../Vector/Transforms/VectorDistribute.cpp    | 39 +++++--------------
 .../Vector/vector-warp-distribute.mlir        | 36 -----------------
 2 files changed, 10 insertions(+), 65 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 52a9cedb43cc0..045c192787f10 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1554,36 +1554,22 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     llvm::SmallSetVector<Value, 32> escapingValues;
     SmallVector<Type> inputTypes;
     SmallVector<Type> distTypes;
-    auto collectEscapingValues = [&](Value value) {
-      if (!escapingValues.insert(value))
-        return;
-      Type distType = value.getType();
-      if (auto vecType = dyn_cast<VectorType>(distType)) {
-        AffineMap map = distributionMapFn(value);
-        distType = getDistributedType(vecType, map, warpOp.getWarpSize());
-      }
-      inputTypes.push_back(value.getType());
-      distTypes.push_back(distType);
-    };
-
     mlir::visitUsedValuesDefinedAbove(
         forOp.getBodyRegion(), [&](OpOperand *operand) {
           Operation *parent = operand->get().getParentRegion()->getParentOp();
           if (warpOp->isAncestor(parent)) {
-            collectEscapingValues(operand->get());
+            if (!escapingValues.insert(operand->get()))
+              return;
+            Type distType = operand->get().getType();
+            if (auto vecType = dyn_cast<VectorType>(distType)) {
+              AffineMap map = distributionMapFn(operand->get());
+              distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+            }
+            inputTypes.push_back(operand->get().getType());
+            distTypes.push_back(distType);
           }
         });
 
-    // Any forOp result that is not already yielded by the warpOp
-    // region is also considered escaping and must be returned by the
-    // original warpOp.
-    for (OpResult forResult : forOp.getResults()) {
-      // Check if this forResult is already yielded by the yield op.
-      if (llvm::is_contained(yield->getOperands(), forResult))
-        continue;
-      collectEscapingValues(forResult);
-    }
-
     if (llvm::is_contained(distTypes, Type{}))
       return failure();
 
@@ -1623,12 +1609,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
                                     forOp.getResultTypes().end());
     llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
     for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
-      auto newWarpResult = newWarpOp.getResult(retIdx);
-      // Unused forOp results yielded by the warpOp region are already included
-      // in the new ForOp.
-      if (llvm::is_contained(newOperands, newWarpResult))
-        continue;
-      warpInput.push_back(newWarpResult);
+      warpInput.push_back(newWarpOp.getResult(retIdx));
       argIndexMapping[escapingValues[i]] = warpInputType.size();
       warpInputType.push_back(inputTypes[i]);
     }
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 6c7ac7a5196a7..38771f2593449 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -584,42 +584,6 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2
   return
 }
 
-// -----
-// CHECK-PROP-LABEL: func.func @warp_scf_for_unused_yield(
-//       CHECK-PROP: %[[W0:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
-//       CHECK-PROP: %[[INI0:.*]] = "some_def"() : () -> vector<128xf32>
-//       CHECK-PROP: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
-//       CHECK-PROP: gpu.yield %[[INI0]], %[[INI1]] : vector<128xf32>, vector<128xf32>
-//       CHECK-PROP: }
-//       CHECK-PROP: %[[F:.*]]:2 = scf.for %{{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1) -> (vector<4xf32>, vector<4xf32>) {
-//       CHECK-PROP: %[[W1:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) {
-//       CHECK-PROP: %[[ACC0:.*]] = "some_def"(%{{.*}}) : (vector<128xf32>, index) -> vector<128xf32>
-//       CHECK-PROP: %[[ACC1:.*]] = "some_def"(%{{.*}}) : (index, vector<128xf32>, vector<128xf32>) -> vector<128xf32>
-//       CHECK-PROP: gpu.yield %[[ACC1]], %[[ACC0]] : vector<128xf32>, vector<128xf32>
-//       CHECK-PROP: }
-//       CHECK-PROP: scf.yield %[[W1]]#0, %[[W1]]#1 : vector<4xf32>, vector<4xf32>
-//       CHECK-PROP: }
-//       CHECK-PROP: "some_use"(%[[F]]#0) : (vector<4xf32>) -> ()
-func.func @warp_scf_for_unused_yield(%arg0: index) {
-  %c128 = arith.constant 128 : index
-  %c1 = arith.constant 1 : index
-  %c0 = arith.constant 0 : index
-  %0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
-    %ini = "some_def"() : () -> (vector<128xf32>)
-    %ini1 = "some_def"() : () -> (vector<128xf32>)
-    %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -> (vector<128xf32>, vector<128xf32>) {
-      %add = arith.addi %arg3, %c1 : index
-      %1  = "some_def"(%arg5, %add) : (vector<128xf32>, index) -> (vector<128xf32>)
-      %acc = "some_def"(%add, %arg4, %1) : (index, vector<128xf32>, vector<128xf32>) -> (vector<128xf32>)
-      scf.yield %acc, %1 : vector<128xf32>, vector<128xf32>
-    }
-    gpu.yield %3#0 : vector<128xf32>
-  }
-  "some_use"(%0) : (vector<4xf32>) -> ()
-  return
-}
-
-
 // -----
 
 // CHECK-PROP-LABEL: func @vector_reduction(



More information about the Mlir-commits mailing list