[Mlir-commits] [mlir] 10dc8bc - [mlir][vector] Fix for WarpOpScfForOp failure when scf.for has results that are unused. (#141853)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jun 9 11:56:37 PDT 2025
Author: Charitha Saumya
Date: 2025-06-09T11:56:34-07:00
New Revision: 10dc8bc519130f491d70318bd8b47555307cdc3f
URL: https://github.com/llvm/llvm-project/commit/10dc8bc519130f491d70318bd8b47555307cdc3f
DIFF: https://github.com/llvm/llvm-project/commit/10dc8bc519130f491d70318bd8b47555307cdc3f.diff
LOG: [mlir][vector] Fix for WarpOpScfForOp failure when scf.for has results that are unused. (#141853)
Currently, only values defined outside the ForOp but inside the original
WarpOp are considered "escaping values". However, this criterion is not
sufficient if the ForOp has unused results, i.e. results that are not
yielded by the original WarpOp region. In this case, the corresponding
IterArgs must also be yielded by the original WarpOp. This PR adds the
required code changes to achieve this.
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 045c192787f10..52a9cedb43cc0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1554,22 +1554,36 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
llvm::SmallSetVector<Value, 32> escapingValues;
SmallVector<Type> inputTypes;
SmallVector<Type> distTypes;
+ auto collectEscapingValues = [&](Value value) {
+ if (!escapingValues.insert(value))
+ return;
+ Type distType = value.getType();
+ if (auto vecType = dyn_cast<VectorType>(distType)) {
+ AffineMap map = distributionMapFn(value);
+ distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+ }
+ inputTypes.push_back(value.getType());
+ distTypes.push_back(distType);
+ };
+
mlir::visitUsedValuesDefinedAbove(
forOp.getBodyRegion(), [&](OpOperand *operand) {
Operation *parent = operand->get().getParentRegion()->getParentOp();
if (warpOp->isAncestor(parent)) {
- if (!escapingValues.insert(operand->get()))
- return;
- Type distType = operand->get().getType();
- if (auto vecType = dyn_cast<VectorType>(distType)) {
- AffineMap map = distributionMapFn(operand->get());
- distType = getDistributedType(vecType, map, warpOp.getWarpSize());
- }
- inputTypes.push_back(operand->get().getType());
- distTypes.push_back(distType);
+ collectEscapingValues(operand->get());
}
});
+ // Any forOp result that is not already yielded by the warpOp
+ // region is also considered escaping and must be returned by the
+ // original warpOp.
+ for (OpResult forResult : forOp.getResults()) {
+ // Check if this forResult is already yielded by the yield op.
+ if (llvm::is_contained(yield->getOperands(), forResult))
+ continue;
+ collectEscapingValues(forResult);
+ }
+
if (llvm::is_contained(distTypes, Type{}))
return failure();
@@ -1609,7 +1623,12 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
forOp.getResultTypes().end());
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
- warpInput.push_back(newWarpOp.getResult(retIdx));
+ auto newWarpResult = newWarpOp.getResult(retIdx);
+ // Unused forOp results yielded by the warpOp region are already included
+ // in the new ForOp.
+ if (llvm::is_contained(newOperands, newWarpResult))
+ continue;
+ warpInput.push_back(newWarpResult);
argIndexMapping[escapingValues[i]] = warpInputType.size();
warpInputType.push_back(inputTypes[i]);
}
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 38771f2593449..6c7ac7a5196a7 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -584,6 +584,42 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2
return
}
+// -----
+// CHECK-PROP-LABEL: func.func @warp_scf_for_unused_yield(
+// CHECK-PROP: %[[W0:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP: %[[INI0:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP: gpu.yield %[[INI0]], %[[INI1]] : vector<128xf32>, vector<128xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[F:.*]]:2 = scf.for %{{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1) -> (vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP: %[[W1:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP: %[[ACC0:.*]] = "some_def"(%{{.*}}) : (vector<128xf32>, index) -> vector<128xf32>
+// CHECK-PROP: %[[ACC1:.*]] = "some_def"(%{{.*}}) : (index, vector<128xf32>, vector<128xf32>) -> vector<128xf32>
+// CHECK-PROP: gpu.yield %[[ACC1]], %[[ACC0]] : vector<128xf32>, vector<128xf32>
+// CHECK-PROP: }
+// CHECK-PROP: scf.yield %[[W1]]#0, %[[W1]]#1 : vector<4xf32>, vector<4xf32>
+// CHECK-PROP: }
+// CHECK-PROP: "some_use"(%[[F]]#0) : (vector<4xf32>) -> ()
+func.func @warp_scf_for_unused_yield(%arg0: index) {
+ %c128 = arith.constant 128 : index
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
+ %ini = "some_def"() : () -> (vector<128xf32>)
+ %ini1 = "some_def"() : () -> (vector<128xf32>)
+ %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -> (vector<128xf32>, vector<128xf32>) {
+ %add = arith.addi %arg3, %c1 : index
+ %1 = "some_def"(%arg5, %add) : (vector<128xf32>, index) -> (vector<128xf32>)
+ %acc = "some_def"(%add, %arg4, %1) : (index, vector<128xf32>, vector<128xf32>) -> (vector<128xf32>)
+ scf.yield %acc, %1 : vector<128xf32>, vector<128xf32>
+ }
+ gpu.yield %3#0 : vector<128xf32>
+ }
+ "some_use"(%0) : (vector<4xf32>) -> ()
+ return
+}
+
+
// -----
// CHECK-PROP-LABEL: func @vector_reduction(
More information about the Mlir-commits
mailing list