[Mlir-commits] [mlir] 3092b76 - [mlir][vector] Refactor WarpOpScfForOp to support unused or swapped forOp results. (#147620)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 11 13:08:37 PDT 2025


Author: Charitha Saumya
Date: 2025-07-11T13:08:33-07:00
New Revision: 3092b765ba0b2d20bd716944dda86ea8e4ad12e3

URL: https://github.com/llvm/llvm-project/commit/3092b765ba0b2d20bd716944dda86ea8e4ad12e3
DIFF: https://github.com/llvm/llvm-project/commit/3092b765ba0b2d20bd716944dda86ea8e4ad12e3.diff

LOG: [mlir][vector] Refactor WarpOpScfForOp to support unused or swapped forOp results. (#147620)

The current implementation generates incorrect code or crashes in the
following valid cases.

1. At least one of the forOp results is not yielded by the warpOp.
Example:
```
%0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
    ....
    %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -> (vector<128xf32>, vector<128xf32>) {
      
      %1  = ...
      %acc = ....
      scf.yield %acc, %1 : vector<128xf32>, vector<128xf32>
    }
    gpu.yield %3#0 : vector<128xf32> // %3#1 is not used but cannot be removed as dead code (loop carried).
  }
  "some_use"(%0) : (vector<4xf32>) -> ()
  return
```
2. The enclosing warpOp yields the forOp results in a different order than
the forOp produces them.
Example:
```
  %0:3 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>, vector<4xf32>, vector<8xf32>) {
    ....
    %3:3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini1, %arg5 = %ini2, %arg6 = %ini3) -> (vector<256xf32>, vector<128xf32>, vector<128xf32>) {
      .....
      scf.yield %acc1, %acc2, %acc3 : vector<256xf32>, vector<128xf32>, vector<128xf32>
    }
    gpu.yield %3#2, %3#1, %3#0 : vector<128xf32>, vector<128xf32>, vector<256xf32> // swapped order
  }
  "some_use_1"(%0#0) : (vector<4xf32>) -> ()
  "some_use_2"(%0#1) : (vector<4xf32>) -> ()
  "some_use_3"(%0#2) : (vector<8xf32>) -> ()

```

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index c8566b1ff83ef..e62031412eab6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1704,19 +1704,18 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
       : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    auto yield = cast<gpu::YieldOp>(
+    auto warpOpYield = cast<gpu::YieldOp>(
         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-    // Only pick up forOp if it is the last op in the region.
-    Operation *lastNode = yield->getPrevNode();
+    // Only pick up `ForOp` if it is the last op in the region.
+    Operation *lastNode = warpOpYield->getPrevNode();
     auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
     if (!forOp)
       return failure();
-    // Collect Values that come from the warp op but are outside the forOp.
-    // Those Value needs to be returned by the original warpOp and passed to
-    // the new op.
+    // Collect Values that come from the `WarpOp` but are outside the `ForOp`.
+    // Those Values need to be returned by the new warp op.
     llvm::SmallSetVector<Value, 32> escapingValues;
-    SmallVector<Type> inputTypes;
-    SmallVector<Type> distTypes;
+    SmallVector<Type> escapingValueInputTypes;
+    SmallVector<Type> escapingValueDistTypes;
     mlir::visitUsedValuesDefinedAbove(
         forOp.getBodyRegion(), [&](OpOperand *operand) {
           Operation *parent = operand->get().getParentRegion()->getParentOp();
@@ -1728,81 +1727,153 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
               AffineMap map = distributionMapFn(operand->get());
               distType = getDistributedType(vecType, map, warpOp.getWarpSize());
             }
-            inputTypes.push_back(operand->get().getType());
-            distTypes.push_back(distType);
+            escapingValueInputTypes.push_back(operand->get().getType());
+            escapingValueDistTypes.push_back(distType);
           }
         });
 
-    if (llvm::is_contained(distTypes, Type{}))
+    if (llvm::is_contained(escapingValueDistTypes, Type{}))
       return failure();
-
-    SmallVector<size_t> newRetIndices;
-    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
-        newRetIndices);
-    yield = cast<gpu::YieldOp>(
-        newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-
-    SmallVector<Value> newOperands;
-    SmallVector<unsigned> resultIdx;
-    // Collect all the outputs coming from the forOp.
-    for (OpOperand &yieldOperand : yield->getOpOperands()) {
-      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
+    // `WarpOp` can yield two types of values:
+    // 1. Values that are not results of the `ForOp`:
+    //    These values must also be yielded by the new `WarpOp`. Also, we need
+    //    to record the index mapping for these values to replace them later.
+    // 2. Values that are results of the `ForOp`:
+    //    In this case, we record the index mapping between the `WarpOp` result
+    //    index and matching `ForOp` result index.
+    SmallVector<Value> nonForYieldedValues;
+    SmallVector<unsigned> nonForResultIndices;
+    llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
+    for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
+      // Yielded value is not a result of the forOp.
+      if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
+        nonForYieldedValues.push_back(yieldOperand.get());
+        nonForResultIndices.push_back(yieldOperand.getOperandNumber());
         continue;
-      auto forResult = cast<OpResult>(yieldOperand.get());
-      newOperands.push_back(
-          newWarpOp.getResult(yieldOperand.getOperandNumber()));
-      yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
-      resultIdx.push_back(yieldOperand.getOperandNumber());
+      }
+      OpResult forResult = cast<OpResult>(yieldOperand.get());
+      forResultMapping[yieldOperand.getOperandNumber()] =
+          forResult.getResultNumber();
     }
 
+    // Newly created `WarpOp` will yield values in following order:
+    // 1. All init args of the `ForOp`.
+    // 2. All escaping values.
+    // 3. All non-`ForOp` yielded values.
+    SmallVector<Value> newWarpOpYieldValues;
+    SmallVector<Type> newWarpOpDistTypes;
+    for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
+      newWarpOpYieldValues.push_back(initArg);
+      // Compute the distributed type for this init arg.
+      Type distType = initArg.getType();
+      if (auto vecType = dyn_cast<VectorType>(distType)) {
+        AffineMap map = distributionMapFn(initArg);
+        distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+      }
+      newWarpOpDistTypes.push_back(distType);
+    }
+    // Insert escaping values and their distributed types.
+    newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
+                                escapingValues.begin(), escapingValues.end());
+    newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
+                              escapingValueDistTypes.begin(),
+                              escapingValueDistTypes.end());
+    // Next, we insert all non-`ForOp` yielded values and their distributed
+    // types. We also create a mapping between the non-`ForOp` yielded value
+    // index and the corresponding new `WarpOp` yield value index (needed to
+    // update users later).
+    llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
+    for (auto [i, v] :
+         llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
+      nonForResultMapping[i] = newWarpOpYieldValues.size();
+      newWarpOpYieldValues.push_back(v);
+      newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
+    }
+    // Create the new `WarpOp` with the updated yield values and types.
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+
+    // Next, we create a new `ForOp` with the init args yielded by the new
+    // `WarpOp`.
+    const unsigned escapingValuesStartIdx =
+        forOp.getInitArgs().size(); // `ForOp` init args are positioned before
+                                    // escaping values in the new `WarpOp`.
+    SmallVector<Value> newForOpOperands;
+    for (size_t i = 0; i < escapingValuesStartIdx; ++i)
+      newForOpOperands.push_back(newWarpOp.getResult(i));
+
+    // Create a new `ForOp` outside the new `WarpOp` region.
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(newWarpOp);
-
-    // Create a new for op outside the region with a WarpExecuteOnLane0Op
-    // region inside.
     auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
-        forOp.getStep(), newOperands);
+        forOp.getStep(), newForOpOperands);
+    // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
+    // newly created `ForOp`. This `WarpOp` will contain all ops that were
+    // contained within the original `ForOp` body.
     rewriter.setInsertionPointToStart(newForOp.getBody());
 
-    SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
-                                 newForOp.getRegionIterArgs().end());
-    SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
-                                    forOp.getResultTypes().end());
+    SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
+                                      newForOp.getRegionIterArgs().end());
+    SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
+                                         forOp.getResultTypes().end());
+    // Escaping values are forwarded to the inner `WarpOp` as its (additional)
+    // arguments. We keep track of the mapping between these values and their
+    // argument index in the inner `WarpOp` (to replace users later).
     llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
-    for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
-      warpInput.push_back(newWarpOp.getResult(retIdx));
-      argIndexMapping[escapingValues[i]] = warpInputType.size();
-      warpInputType.push_back(inputTypes[i]);
+    for (size_t i = escapingValuesStartIdx;
+         i < escapingValuesStartIdx + escapingValues.size(); ++i) {
+      innerWarpInput.push_back(newWarpOp.getResult(i));
+      argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
+          innerWarpInputType.size();
+      innerWarpInputType.push_back(
+          escapingValueInputTypes[i - escapingValuesStartIdx]);
     }
+    // Create the inner `WarpOp` with the new input values and types.
     auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
         newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
-        newWarpOp.getWarpSize(), warpInput, warpInputType);
+        newWarpOp.getWarpSize(), innerWarpInput, innerWarpInputType);
 
+    // Inline the `ForOp` body into the inner `WarpOp` body.
     SmallVector<Value> argMapping;
     argMapping.push_back(newForOp.getInductionVar());
-    for (Value args : innerWarp.getBody()->getArguments()) {
+    for (Value args : innerWarp.getBody()->getArguments())
       argMapping.push_back(args);
-    }
+
     argMapping.resize(forOp.getBody()->getNumArguments());
     SmallVector<Value> yieldOperands;
     for (Value operand : forOp.getBody()->getTerminator()->getOperands())
       yieldOperands.push_back(operand);
+
     rewriter.eraseOp(forOp.getBody()->getTerminator());
     rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
+
+    // Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields
+    // original `ForOp` results.
     rewriter.setInsertionPointToEnd(innerWarp.getBody());
     rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
     rewriter.setInsertionPointAfter(innerWarp);
+    // Insert a scf.yield op at the end of the new `ForOp` body that yields
+    // the inner `WarpOp` results.
     if (!innerWarp.getResults().empty())
       rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
+
+    // Update the users of original `WarpOp` results that were coming from the
+    // original `ForOp` to the corresponding new `ForOp` result.
+    for (auto [origIdx, newIdx] : forResultMapping)
+      rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+                                    newForOp.getResult(newIdx), newForOp);
+    // Similarly, update any users of the `WarpOp` results that were not
+    // results of the `ForOp`.
+    for (auto [origIdx, newIdx] : nonForResultMapping)
+      rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
+                                  newWarpOp.getResult(newIdx));
+    // Remove the original `WarpOp` and `ForOp`, they should not have any uses
+    // at this point.
     rewriter.eraseOp(forOp);
-    // Replace the warpOp result coming from the original ForOp.
-    for (const auto &res : llvm::enumerate(resultIdx)) {
-      rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
-                                  newForOp.getResult(res.index()));
-      newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
-    }
+    rewriter.eraseOp(warpOp);
+    // Update any users of escaping values that were forwarded to the
+    // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
     newForOp.walk([&](Operation *op) {
       for (OpOperand &operand : op->getOpOperands()) {
         auto it = argIndexMapping.find(operand.get());
@@ -1812,7 +1883,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
       }
     });
 
-    // Finally, hoist out any now uniform code from the inner warp op.
+    // Finally, hoist out any now uniform code from the inner `WarpOp`.
     mlir::vector::moveScalarUniformCode(innerWarp);
     return success();
   }

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index c072557c2bd22..5319496edc5af 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -876,15 +876,32 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
   // Step 3: Apply subgroup to workitem distribution patterns.
   RewritePatternSet patterns(&getContext());
   xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
-  // TODO: distributionFn and shuffleFn are not used at this point.
+  // distributionFn is used by vector distribution patterns to determine the
+  // distributed vector type for a given vector value. In XeGPU subgroup
+  // distribution context, we compute this based on lane layout.
   auto distributionFn = [](Value val) {
     VectorType vecType = dyn_cast<VectorType>(val.getType());
     int64_t vecRank = vecType ? vecType.getRank() : 0;
-    OpBuilder builder(val.getContext());
     if (vecRank == 0)
       return AffineMap::get(val.getContext());
-    return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
+    // Query the xegpu layout associated with this value (may be null).
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(val);
+    // If no layout is found, assume the innermost dimension is distributed
+    // for now.
+    if (!layout)
+      return AffineMap::getMultiDimMapWithTargets(
+          vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
+    SmallVector<unsigned int> distributedDims;
+    // A dimension is distributed when its lane layout extent is > 1.
+    ArrayRef<int> laneLayout = layout.getLaneLayout().asArrayRef();
+    for (unsigned i = 0; i < laneLayout.size(); ++i) {
+      if (laneLayout[i] > 1)
+        distributedDims.push_back(i);
+    }
+    return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
+                                                val.getContext());
   };
+  // TODO: shuffleFn is not used.
   auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
                       int64_t warpSz) { return Value(); };
   vector::populatePropagateWarpVectorDistributionPatterns(

diff  --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 9fa9d56e4a324..c6342f07fc314 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -584,6 +584,85 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2
   return
 }
 
+// -----
+// CHECK-PROP-LABEL: func.func @warp_scf_for_unused_for_result(
+//       CHECK-PROP: %[[W0:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
+//       CHECK-PROP:  %[[INI0:.*]] = "some_def"() : () -> vector<128xf32>
+//       CHECK-PROP:  %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
+//       CHECK-PROP:  gpu.yield %[[INI0]], %[[INI1]] : vector<128xf32>, vector<128xf32>
+//       CHECK-PROP: }
+//       CHECK-PROP: %[[F:.*]]:2 = scf.for %{{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1) -> (vector<4xf32>, vector<4xf32>) {
+//       CHECK-PROP:  %[[W1:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) {
+//       CHECK-PROP:    %[[ACC0:.*]] = "some_def"(%{{.*}}) : (vector<128xf32>, index) -> vector<128xf32>
+//       CHECK-PROP:    %[[ACC1:.*]] = "some_def"(%{{.*}}) : (index, vector<128xf32>, vector<128xf32>) -> vector<128xf32>
+//       CHECK-PROP:    gpu.yield %[[ACC1]], %[[ACC0]] : vector<128xf32>, vector<128xf32>
+//       CHECK-PROP:  }
+//       CHECK-PROP:  scf.yield %[[W1]]#0, %[[W1]]#1 : vector<4xf32>, vector<4xf32>
+//       CHECK-PROP: }
+//       CHECK-PROP: "some_use"(%[[F]]#0) : (vector<4xf32>) -> ()
+func.func @warp_scf_for_unused_for_result(%arg0: index) {
+  %c128 = arith.constant 128 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
+    %ini = "some_def"() : () -> (vector<128xf32>)
+    %ini1 = "some_def"() : () -> (vector<128xf32>)
+    %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -> (vector<128xf32>, vector<128xf32>) {
+      %add = arith.addi %arg3, %c1 : index
+      %1  = "some_def"(%arg5, %add) : (vector<128xf32>, index) -> (vector<128xf32>)
+      %acc = "some_def"(%add, %arg4, %1) : (index, vector<128xf32>, vector<128xf32>) -> (vector<128xf32>)
+      scf.yield %acc, %1 : vector<128xf32>, vector<128xf32>
+    }
+    gpu.yield %3#0 : vector<128xf32>
+  }
+  "some_use"(%0) : (vector<4xf32>) -> ()
+  return
+}
+
+// -----
+// CHECK-PROP-LABEL: func.func @warp_scf_for_swapped_for_results(
+//       CHECK-PROP:  %[[W0:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
+//  CHECK-PROP-NEXT:    %[[INI0:.*]] = "some_def"() : () -> vector<256xf32>
+//  CHECK-PROP-NEXT:    %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
+//  CHECK-PROP-NEXT:    %[[INI2:.*]] = "some_def"() : () -> vector<128xf32>
+//  CHECK-PROP-NEXT:    gpu.yield %[[INI0]], %[[INI1]], %[[INI2]] : vector<256xf32>, vector<128xf32>, vector<128xf32>
+//  CHECK-PROP-NEXT:  }
+//  CHECK-PROP-NEXT:  %[[F0:.*]]:3 = scf.for {{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1, %{{.*}} = %[[W0]]#2) -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
+//  CHECK-PROP-NEXT:    %[[W1:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} :
+//  CHECK-PROP-SAME:        vector<8xf32>, vector<4xf32>, vector<4xf32>) -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
+//  CHECK-PROP-NEXT:      ^bb0(%{{.*}}: vector<256xf32>, %{{.*}}: vector<128xf32>, %{{.*}}: vector<128xf32>):
+//  CHECK-PROP-NEXT:        %[[T3:.*]] = "some_def_1"(%{{.*}}) : (vector<256xf32>) -> vector<256xf32>
+//  CHECK-PROP-NEXT:        %[[T4:.*]] = "some_def_2"(%{{.*}}) : (vector<128xf32>) -> vector<128xf32>
+//  CHECK-PROP-NEXT:        %[[T5:.*]] = "some_def_3"(%{{.*}}) : (vector<128xf32>) -> vector<128xf32>
+//  CHECK-PROP-NEXT:        gpu.yield %[[T3]], %[[T4]], %[[T5]] : vector<256xf32>, vector<128xf32>, vector<128xf32>
+//  CHECK-PROP-NEXT:    }
+//  CHECK-PROP-NEXT:    scf.yield %[[W1]]#0, %[[W1]]#1, %[[W1]]#2 : vector<8xf32>, vector<4xf32>, vector<4xf32>
+//  CHECK-PROP-NEXT:  }
+//  CHECK-PROP-NEXT:  "some_use_1"(%[[F0]]#2) : (vector<4xf32>) -> ()
+//  CHECK-PROP-NEXT:  "some_use_2"(%[[F0]]#1) : (vector<4xf32>) -> ()
+//  CHECK-PROP-NEXT:  "some_use_3"(%[[F0]]#0) : (vector<8xf32>) -> ()
+func.func @warp_scf_for_swapped_for_results(%arg0: index) {
+  %c128 = arith.constant 128 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %0:3 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>, vector<4xf32>, vector<8xf32>) {
+    %ini1 = "some_def"() : () -> (vector<256xf32>)
+    %ini2 = "some_def"() : () -> (vector<128xf32>)
+    %ini3 = "some_def"() : () -> (vector<128xf32>)
+    %3:3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini1, %arg5 = %ini2, %arg6 = %ini3) -> (vector<256xf32>, vector<128xf32>, vector<128xf32>) {
+      %acc1 = "some_def_1"(%arg4) : (vector<256xf32>) -> (vector<256xf32>)
+      %acc2 = "some_def_2"(%arg5) : (vector<128xf32>) -> (vector<128xf32>)
+      %acc3 = "some_def_3"(%arg6) : (vector<128xf32>) -> (vector<128xf32>)
+      scf.yield %acc1, %acc2, %acc3 : vector<256xf32>, vector<128xf32>, vector<128xf32>
+    }
+    gpu.yield %3#2, %3#1, %3#0 : vector<128xf32>, vector<128xf32>, vector<256xf32>
+  }
+  "some_use_1"(%0#0) : (vector<4xf32>) -> ()
+  "some_use_2"(%0#1) : (vector<4xf32>) -> ()
+  "some_use_3"(%0#2) : (vector<8xf32>) -> ()
+  return
+}
+
 // -----
 
 // CHECK-PROP-LABEL: func @vector_reduction(


        


More information about the Mlir-commits mailing list