[Mlir-commits] [mlir] [mlir][vector] Fix n-d transfer write distribution (PR #83215)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 27 18:10:28 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Quinn Dawkins (qedawkins)
<details>
<summary>Changes</summary>
Currently, n-D transfer write distribution can be inconsistent with the distribution of reductions when a value has multiple users: one a transfer_write with a non-standard distribution map, and the other a vector.reduction.
We may want to consider removing the distribution map functionality in the future for this reason.
---
Full diff: https://github.com/llvm/llvm-project/pull/83215.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+33-6)
- (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+25)
- (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+1-1)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 620ceee48b196d..34f25d2312ef8f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -443,15 +443,24 @@ static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
/// d1) and return vector<16x2x64>
static VectorType getDistributedType(VectorType originalType, AffineMap map,
int64_t warpSize) {
- if (map.getNumResults() != 1)
- return VectorType();
SmallVector<int64_t> targetShape(originalType.getShape().begin(),
originalType.getShape().end());
for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
unsigned position = map.getDimPosition(i);
- if (targetShape[position] % warpSize != 0)
- return VectorType();
+ if (targetShape[position] % warpSize != 0) {
+ if (warpSize % targetShape[position] != 0) {
+ return VectorType();
+ }
+ warpSize /= targetShape[position];
+ targetShape[position] = 1;
+ continue;
+ }
targetShape[position] = targetShape[position] / warpSize;
+ warpSize = 1;
+ break;
+ }
+ if (warpSize != 1) {
+ return VectorType();
}
VectorType targetType =
VectorType::get(targetShape, originalType.getElementType());
@@ -526,7 +535,25 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
// 4. Reindex the write using the distribution map.
auto newWarpOp =
newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
+
rewriter.setInsertionPoint(newWriteOp);
+ SmallVector<OpFoldResult> delinearizedIdSizes;
+ for (auto [seqSize, distSize] :
+ llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
+ assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
+ delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
+ }
+ SmallVector<Value> delinearized;
+ if (map.getNumResults() > 1) {
+ delinearized = rewriter
+ .create<mlir::affine::AffineDelinearizeIndexOp>(
+ newWarpOp.getLoc(), newWarpOp.getLaneid(),
+ delinearizedIdSizes)
+ .getResults();
+ } else {
+ delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
+ }
+
AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
Location loc = newWriteOp.getLoc();
SmallVector<Value> indices(newWriteOp.getIndices().begin(),
@@ -539,11 +566,11 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
continue;
unsigned indexPos = indexExpr.getPosition();
unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
+ Value laneId = delinearized[vectorPos];
auto scale =
rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
indices[indexPos] = affine::makeComposedAffineApply(
- rewriter, loc, d0 + scale * d1,
- {indices[indexPos], newWarpOp.getLaneid()});
+ rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
}
newWriteOp.getIndicesMutable().assign(indices);
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 9072603734879e..bf90c4a6ebb3c2 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1559,3 +1559,28 @@ func.func @warp_propagate_multi_dim_create_mask(%laneid: index, %m0: index, %m1:
// CHECK-PROP: %[[DISTM0:.+]] = affine.apply #[[$SUBM0]]()[%[[M0]], %[[LANEID]]]
// CHECK-PROP: %[[DISTM1:.+]] = affine.apply #[[$SUBM1]]()[%[[M1]], %[[LANEID]]]
// CHECK-PROP: vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[M2]] : vector<1x2x4xi1>
+
+// -----
+
+func.func @warp_propagate_nd_write(%laneid: index, %dest: memref<4x1024xf32>) {
+ %c0 = arith.constant 0 : index
+ vector.warp_execute_on_lane_0(%laneid)[32] -> () {
+ %0 = "some_def"() : () -> (vector<4x1024xf32>)
+ vector.transfer_write %0, %dest[%c0, %c0] : vector<4x1024xf32>, memref<4x1024xf32>
+ vector.yield
+ }
+ return
+}
+
+// CHECK-DIST-AND-PROP: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 128)>
+
+// CHECK-DIST-AND-PROP-LABEL: func.func @warp_propagate_nd_write(
+// CHECK-DIST-AND-PROP: %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1x128xf32>) {
+// CHECK-DIST-AND-PROP: %[[V0:.*]] = "some_def"
+// CHECK-DIST-AND-PROP: vector.yield %[[V0]]
+// CHECK-DIST-AND-PROP-SAME: vector<4x1024xf32>
+// CHECK-DIST-AND-PROP: }
+
+// CHECK-DIST-AND-PROP: %[[IDS:.+]]:2 = affine.delinearize_index %{{.*}} into (%c4, %c8) : index, index
+// CHECK-DIST-AND-PROP: %[[INNER_ID:.+]] = affine.apply #map()[%[[IDS]]#1]
+// CHECK-DIST-AND-PROP: vector.transfer_write %[[W]], %{{.*}}[%[[IDS]]#0, %[[INNER_ID]]] {{.*}} : vector<1x128xf32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 178a58e796b246..d8453c78d52d4f 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -638,7 +638,7 @@ struct TestVectorDistribution
OpBuilder builder(val.getContext());
if (vecRank == 0)
return AffineMap::get(val.getContext());
- return AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
+ return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
};
auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
Value srcIdx, int64_t warpSz) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/83215
More information about the Mlir-commits
mailing list