[Mlir-commits] [mlir] 3092b76 - [mlir][vector] Refactor WarpOpScfForOp to support unused or swapped forOp results. (#147620)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jul 11 13:08:37 PDT 2025
Author: Charitha Saumya
Date: 2025-07-11T13:08:33-07:00
New Revision: 3092b765ba0b2d20bd716944dda86ea8e4ad12e3
URL: https://github.com/llvm/llvm-project/commit/3092b765ba0b2d20bd716944dda86ea8e4ad12e3
DIFF: https://github.com/llvm/llvm-project/commit/3092b765ba0b2d20bd716944dda86ea8e4ad12e3.diff
LOG: [mlir][vector] Refactor WarpOpScfForOp to support unused or swapped forOp results. (#147620)
Current implementation generates incorrect code or crashes in the
following valid cases.
1. At least one of the forOp results is not yielded by the warpOp.
Example:
```
%0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
....
%3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -> (vector<128xf32>, vector<128xf32>) {
%1 = ...
%acc = ....
scf.yield %acc, %1 : vector<128xf32>, vector<128xf32>
}
gpu.yield %3#0 : vector<128xf32> // %3#1 is not used but can not be removed as dead code (loop carried).
}
"some_use"(%0) : (vector<4xf32>) -> ()
return
```
2. The enclosing warpOp yields the forOp results in a different order
than the order in which the forOp produces them.
Example:
```
%0:3 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>, vector<4xf32>, vector<8xf32>) {
....
%3:3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini1, %arg5 = %ini2, %arg6 = %ini3) -> (vector<256xf32>, vector<128xf32>, vector<128xf32>) {
.....
scf.yield %acc1, %acc2, %acc3 : vector<256xf32>, vector<128xf32>, vector<128xf32>
}
gpu.yield %3#2, %3#1, %3#0 : vector<128xf32>, vector<128xf32>, vector<256xf32> // swapped order
}
"some_use_1"(%0#0) : (vector<4xf32>) -> ()
"some_use_2"(%0#1) : (vector<4xf32>) -> ()
"some_use_3"(%0#2) : (vector<8xf32>) -> ()
```
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index c8566b1ff83ef..e62031412eab6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1704,19 +1704,18 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
: WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto yield = cast<gpu::YieldOp>(
+ auto warpOpYield = cast<gpu::YieldOp>(
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
- // Only pick up forOp if it is the last op in the region.
- Operation *lastNode = yield->getPrevNode();
+ // Only pick up `ForOp` if it is the last op in the region.
+ Operation *lastNode = warpOpYield->getPrevNode();
auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
if (!forOp)
return failure();
- // Collect Values that come from the warp op but are outside the forOp.
- // Those Value needs to be returned by the original warpOp and passed to
- // the new op.
+ // Collect Values that come from the `WarpOp` but are outside the `ForOp`.
+ // Those Values need to be returned by the new warp op.
llvm::SmallSetVector<Value, 32> escapingValues;
- SmallVector<Type> inputTypes;
- SmallVector<Type> distTypes;
+ SmallVector<Type> escapingValueInputTypes;
+ SmallVector<Type> escapingValueDistTypes;
mlir::visitUsedValuesDefinedAbove(
forOp.getBodyRegion(), [&](OpOperand *operand) {
Operation *parent = operand->get().getParentRegion()->getParentOp();
@@ -1728,81 +1727,153 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
AffineMap map = distributionMapFn(operand->get());
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
}
- inputTypes.push_back(operand->get().getType());
- distTypes.push_back(distType);
+ escapingValueInputTypes.push_back(operand->get().getType());
+ escapingValueDistTypes.push_back(distType);
}
});
- if (llvm::is_contained(distTypes, Type{}))
+ if (llvm::is_contained(escapingValueDistTypes, Type{}))
return failure();
-
- SmallVector<size_t> newRetIndices;
- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
- newRetIndices);
- yield = cast<gpu::YieldOp>(
- newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-
- SmallVector<Value> newOperands;
- SmallVector<unsigned> resultIdx;
- // Collect all the outputs coming from the forOp.
- for (OpOperand &yieldOperand : yield->getOpOperands()) {
- if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
+ // `WarpOp` can yield two types of values:
+ // 1. Values that are not results of the `ForOp`:
+ // These values must also be yielded by the new `WarpOp`. Also, we need
+ // to record the index mapping for these values to replace them later.
+ // 2. Values that are results of the `ForOp`:
+ // In this case, we record the index mapping between the `WarpOp` result
+ // index and matching `ForOp` result index.
+ SmallVector<Value> nonForYieldedValues;
+ SmallVector<unsigned> nonForResultIndices;
+ llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
+ for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
+ // Yielded value is not a result of the forOp.
+ if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
+ nonForYieldedValues.push_back(yieldOperand.get());
+ nonForResultIndices.push_back(yieldOperand.getOperandNumber());
continue;
- auto forResult = cast<OpResult>(yieldOperand.get());
- newOperands.push_back(
- newWarpOp.getResult(yieldOperand.getOperandNumber()));
- yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
- resultIdx.push_back(yieldOperand.getOperandNumber());
+ }
+ OpResult forResult = cast<OpResult>(yieldOperand.get());
+ forResultMapping[yieldOperand.getOperandNumber()] =
+ forResult.getResultNumber();
}
+ // Newly created `WarpOp` will yield values in following order:
+ // 1. All init args of the `ForOp`.
+ // 2. All escaping values.
+ // 3. All non-`ForOp` yielded values.
+ SmallVector<Value> newWarpOpYieldValues;
+ SmallVector<Type> newWarpOpDistTypes;
+ for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
+ newWarpOpYieldValues.push_back(initArg);
+ // Compute the distributed type for this init arg.
+ Type distType = initArg.getType();
+ if (auto vecType = dyn_cast<VectorType>(distType)) {
+ AffineMap map = distributionMapFn(initArg);
+ distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+ }
+ newWarpOpDistTypes.push_back(distType);
+ }
+ // Insert escaping values and their distributed types.
+ newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
+ escapingValues.begin(), escapingValues.end());
+ newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
+ escapingValueDistTypes.begin(),
+ escapingValueDistTypes.end());
+ // Next, we insert all non-`ForOp` yielded values and their distributed
+ // types. We also create a mapping between the non-`ForOp` yielded value
+ // index and the corresponding new `WarpOp` yield value index (needed to
+ // update users later).
+ llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
+ for (auto [i, v] :
+ llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
+ nonForResultMapping[i] = newWarpOpYieldValues.size();
+ newWarpOpYieldValues.push_back(v);
+ newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
+ }
+ // Create the new `WarpOp` with the updated yield values and types.
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+ rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+
+ // Next, we create a new `ForOp` with the init args yielded by the new
+ // `WarpOp`.
+ const unsigned escapingValuesStartIdx =
+ forOp.getInitArgs().size(); // `ForOp` init args are positioned before
+ // escaping values in the new `WarpOp`.
+ SmallVector<Value> newForOpOperands;
+ for (size_t i = 0; i < escapingValuesStartIdx; ++i)
+ newForOpOperands.push_back(newWarpOp.getResult(i));
+
+ // Create a new `ForOp` outside the new `WarpOp` region.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);
-
- // Create a new for op outside the region with a WarpExecuteOnLane0Op
- // region inside.
auto newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep(), newOperands);
+ forOp.getStep(), newForOpOperands);
+ // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
+ // newly created `ForOp`. This `WarpOp` will contain all ops that were
+ // contained within the original `ForOp` body.
rewriter.setInsertionPointToStart(newForOp.getBody());
- SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
- newForOp.getRegionIterArgs().end());
- SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
- forOp.getResultTypes().end());
+ SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
+ newForOp.getRegionIterArgs().end());
+ SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
+ forOp.getResultTypes().end());
+ // Escaping values are forwarded to the inner `WarpOp` as its (additional)
+ // arguments. We keep track of the mapping between these values and their
+ // argument index in the inner `WarpOp` (to replace users later).
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
- for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
- warpInput.push_back(newWarpOp.getResult(retIdx));
- argIndexMapping[escapingValues[i]] = warpInputType.size();
- warpInputType.push_back(inputTypes[i]);
+ for (size_t i = escapingValuesStartIdx;
+ i < escapingValuesStartIdx + escapingValues.size(); ++i) {
+ innerWarpInput.push_back(newWarpOp.getResult(i));
+ argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
+ innerWarpInputType.size();
+ innerWarpInputType.push_back(
+ escapingValueInputTypes[i - escapingValuesStartIdx]);
}
+ // Create the inner `WarpOp` with the new input values and types.
auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
- newWarpOp.getWarpSize(), warpInput, warpInputType);
+ newWarpOp.getWarpSize(), innerWarpInput, innerWarpInputType);
+ // Inline the `ForOp` body into the inner `WarpOp` body.
SmallVector<Value> argMapping;
argMapping.push_back(newForOp.getInductionVar());
- for (Value args : innerWarp.getBody()->getArguments()) {
+ for (Value args : innerWarp.getBody()->getArguments())
argMapping.push_back(args);
- }
+
argMapping.resize(forOp.getBody()->getNumArguments());
SmallVector<Value> yieldOperands;
for (Value operand : forOp.getBody()->getTerminator()->getOperands())
yieldOperands.push_back(operand);
+
rewriter.eraseOp(forOp.getBody()->getTerminator());
rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
+
+ // Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields
+ // original `ForOp` results.
rewriter.setInsertionPointToEnd(innerWarp.getBody());
rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
rewriter.setInsertionPointAfter(innerWarp);
+ // Insert a scf.yield op at the end of the new `ForOp` body that yields
+ // the inner `WarpOp` results.
if (!innerWarp.getResults().empty())
rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
+
+ // Update the users of original `WarpOp` results that were coming from the
+ // original `ForOp` to the corresponding new `ForOp` result.
+ for (auto [origIdx, newIdx] : forResultMapping)
+ rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+ newForOp.getResult(newIdx), newForOp);
+ // Similarly, update any users of the `WarpOp` results that were not
+ // results of the `ForOp`.
+ for (auto [origIdx, newIdx] : nonForResultMapping)
+ rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
+ newWarpOp.getResult(newIdx));
+ // Remove the original `WarpOp` and `ForOp`, they should not have any uses
+ // at this point.
rewriter.eraseOp(forOp);
- // Replace the warpOp result coming from the original ForOp.
- for (const auto &res : llvm::enumerate(resultIdx)) {
- rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
- newForOp.getResult(res.index()));
- newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
- }
+ rewriter.eraseOp(warpOp);
+ // Update any users of escaping values that were forwarded to the
+ // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
newForOp.walk([&](Operation *op) {
for (OpOperand &operand : op->getOpOperands()) {
auto it = argIndexMapping.find(operand.get());
@@ -1812,7 +1883,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
}
});
- // Finally, hoist out any now uniform code from the inner warp op.
+ // Finally, hoist out any now uniform code from the inner `WarpOp`.
mlir::vector::moveScalarUniformCode(innerWarp);
return success();
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index c072557c2bd22..5319496edc5af 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -876,15 +876,32 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
// Step 3: Apply subgroup to workitem distribution patterns.
RewritePatternSet patterns(&getContext());
xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
- // TODO: distributionFn and shuffleFn are not used at this point.
+ // distributionFn is used by vector distribution patterns to determine the
+ // distributed vector type for a given vector value. In XeGPU subgroup
+ // distribution context, we compute this based on lane layout.
auto distributionFn = [](Value val) {
VectorType vecType = dyn_cast<VectorType>(val.getType());
int64_t vecRank = vecType ? vecType.getRank() : 0;
- OpBuilder builder(val.getContext());
if (vecRank == 0)
return AffineMap::get(val.getContext());
- return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
+ // Get the layout of the vector type.
+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr(val);
+ // If no layout is specified, assume the inner most dimension is distributed
+ // for now.
+ if (!layout)
+ return AffineMap::getMultiDimMapWithTargets(
+ vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
+ SmallVector<unsigned int> distributedDims;
+ // Get the distributed dimensions based on the layout.
+ ArrayRef<int> laneLayout = layout.getLaneLayout().asArrayRef();
+ for (unsigned i = 0; i < laneLayout.size(); ++i) {
+ if (laneLayout[i] > 1)
+ distributedDims.push_back(i);
+ }
+ return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
+ val.getContext());
};
+ // TODO: shuffleFn is not used.
auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
int64_t warpSz) { return Value(); };
vector::populatePropagateWarpVectorDistributionPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 9fa9d56e4a324..c6342f07fc314 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -584,6 +584,85 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2
return
}
+// -----
+// CHECK-PROP-LABEL: func.func @warp_scf_for_unused_for_result(
+// CHECK-PROP: %[[W0:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP: %[[INI0:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP: gpu.yield %[[INI0]], %[[INI1]] : vector<128xf32>, vector<128xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[F:.*]]:2 = scf.for %{{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1) -> (vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP: %[[W1:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP: %[[ACC0:.*]] = "some_def"(%{{.*}}) : (vector<128xf32>, index) -> vector<128xf32>
+// CHECK-PROP: %[[ACC1:.*]] = "some_def"(%{{.*}}) : (index, vector<128xf32>, vector<128xf32>) -> vector<128xf32>
+// CHECK-PROP: gpu.yield %[[ACC1]], %[[ACC0]] : vector<128xf32>, vector<128xf32>
+// CHECK-PROP: }
+// CHECK-PROP: scf.yield %[[W1]]#0, %[[W1]]#1 : vector<4xf32>, vector<4xf32>
+// CHECK-PROP: }
+// CHECK-PROP: "some_use"(%[[F]]#0) : (vector<4xf32>) -> ()
+func.func @warp_scf_for_unused_for_result(%arg0: index) {
+ %c128 = arith.constant 128 : index
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
+ %ini = "some_def"() : () -> (vector<128xf32>)
+ %ini1 = "some_def"() : () -> (vector<128xf32>)
+ %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -> (vector<128xf32>, vector<128xf32>) {
+ %add = arith.addi %arg3, %c1 : index
+ %1 = "some_def"(%arg5, %add) : (vector<128xf32>, index) -> (vector<128xf32>)
+ %acc = "some_def"(%add, %arg4, %1) : (index, vector<128xf32>, vector<128xf32>) -> (vector<128xf32>)
+ scf.yield %acc, %1 : vector<128xf32>, vector<128xf32>
+ }
+ gpu.yield %3#0 : vector<128xf32>
+ }
+ "some_use"(%0) : (vector<4xf32>) -> ()
+ return
+}
+
+// -----
+// CHECK-PROP-LABEL: func.func @warp_scf_for_swapped_for_results(
+// CHECK-PROP: %[[W0:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP-NEXT: %[[INI0:.*]] = "some_def"() : () -> vector<256xf32>
+// CHECK-PROP-NEXT: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP-NEXT: %[[INI2:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP-NEXT: gpu.yield %[[INI0]], %[[INI1]], %[[INI2]] : vector<256xf32>, vector<128xf32>, vector<128xf32>
+// CHECK-PROP-NEXT: }
+// CHECK-PROP-NEXT: %[[F0:.*]]:3 = scf.for {{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1, %{{.*}} = %[[W0]]#2) -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP-NEXT: %[[W1:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} :
+// CHECK-PROP-SAME: vector<8xf32>, vector<4xf32>, vector<4xf32>) -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP-NEXT: ^bb0(%{{.*}}: vector<256xf32>, %{{.*}}: vector<128xf32>, %{{.*}}: vector<128xf32>):
+// CHECK-PROP-NEXT: %[[T3:.*]] = "some_def_1"(%{{.*}}) : (vector<256xf32>) -> vector<256xf32>
+// CHECK-PROP-NEXT: %[[T4:.*]] = "some_def_2"(%{{.*}}) : (vector<128xf32>) -> vector<128xf32>
+// CHECK-PROP-NEXT: %[[T5:.*]] = "some_def_3"(%{{.*}}) : (vector<128xf32>) -> vector<128xf32>
+// CHECK-PROP-NEXT: gpu.yield %[[T3]], %[[T4]], %[[T5]] : vector<256xf32>, vector<128xf32>, vector<128xf32>
+// CHECK-PROP-NEXT: }
+// CHECK-PROP-NEXT: scf.yield %[[W1]]#0, %[[W1]]#1, %[[W1]]#2 : vector<8xf32>, vector<4xf32>, vector<4xf32>
+// CHECK-PROP-NEXT: }
+// CHECK-PROP-NEXT: "some_use_1"(%[[F0]]#2) : (vector<4xf32>) -> ()
+// CHECK-PROP-NEXT: "some_use_2"(%[[F0]]#1) : (vector<4xf32>) -> ()
+// CHECK-PROP-NEXT: "some_use_3"(%[[F0]]#0) : (vector<8xf32>) -> ()
+func.func @warp_scf_for_swapped_for_results(%arg0: index) {
+ %c128 = arith.constant 128 : index
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %0:3 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>, vector<4xf32>, vector<8xf32>) {
+ %ini1 = "some_def"() : () -> (vector<256xf32>)
+ %ini2 = "some_def"() : () -> (vector<128xf32>)
+ %ini3 = "some_def"() : () -> (vector<128xf32>)
+ %3:3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini1, %arg5 = %ini2, %arg6 = %ini3) -> (vector<256xf32>, vector<128xf32>, vector<128xf32>) {
+ %acc1 = "some_def_1"(%arg4) : (vector<256xf32>) -> (vector<256xf32>)
+ %acc2 = "some_def_2"(%arg5) : (vector<128xf32>) -> (vector<128xf32>)
+ %acc3 = "some_def_3"(%arg6) : (vector<128xf32>) -> (vector<128xf32>)
+ scf.yield %acc1, %acc2, %acc3 : vector<256xf32>, vector<128xf32>, vector<128xf32>
+ }
+ gpu.yield %3#2, %3#1, %3#0 : vector<128xf32>, vector<128xf32>, vector<256xf32>
+ }
+ "some_use_1"(%0#0) : (vector<4xf32>) -> ()
+ "some_use_2"(%0#1) : (vector<4xf32>) -> ()
+ "some_use_3"(%0#2) : (vector<8xf32>) -> ()
+ return
+}
+
// -----
// CHECK-PROP-LABEL: func @vector_reduction(
More information about the Mlir-commits
mailing list