[Mlir-commits] [mlir] c2b9529 - [mlir][vector] Fix n-d transfer write distribution (#83215)
llvmlistbot at llvm.org
Tue Feb 27 21:11:32 PST 2024
Author: Quinn Dawkins
Date: 2024-02-28T00:11:28-05:00
New Revision: c2b952926fe8707527cf1b8bab211dc4c7ab9aee
URL: https://github.com/llvm/llvm-project/commit/c2b952926fe8707527cf1b8bab211dc4c7ab9aee
DIFF: https://github.com/llvm/llvm-project/commit/c2b952926fe8707527cf1b8bab211dc4c7ab9aee.diff
LOG: [mlir][vector] Fix n-d transfer write distribution (#83215)
Currently, n-D transfer write distribution can be inconsistent with the
distribution of reductions if a value has multiple users, one of which
is a transfer_write with a non-standard distribution map and the other
of which is a vector.reduction.
We may want to consider removing the distribution map functionality in
the future for this reason.
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 620ceee48b196d..b3ab4a916121e3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -443,15 +443,24 @@ static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
/// d1) and return vector<16x2x64>
static VectorType getDistributedType(VectorType originalType, AffineMap map,
int64_t warpSize) {
- if (map.getNumResults() != 1)
- return VectorType();
SmallVector<int64_t> targetShape(originalType.getShape().begin(),
originalType.getShape().end());
for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
unsigned position = map.getDimPosition(i);
- if (targetShape[position] % warpSize != 0)
- return VectorType();
+ if (targetShape[position] % warpSize != 0) {
+ if (warpSize % targetShape[position] != 0) {
+ return VectorType();
+ }
+ warpSize /= targetShape[position];
+ targetShape[position] = 1;
+ continue;
+ }
targetShape[position] = targetShape[position] / warpSize;
+ warpSize = 1;
+ break;
+ }
+ if (warpSize != 1) {
+ return VectorType();
}
VectorType targetType =
VectorType::get(targetShape, originalType.getElementType());
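[Editorial aside: for intuition, here is a minimal standalone C++ sketch, not
part of the patch, that mirrors the updated shape computation above; the
function and variable names are illustrative only. It includes the worked
example from the new test: vector<4x1024xf32> over 32 lanes.]

  #include <cassert>
  #include <cstdint>
  #include <cstdio>
  #include <optional>
  #include <vector>

  // Returns the distributed shape, or std::nullopt when the warp size cannot
  // be fully consumed by the dimensions selected by `mapPositions`.
  static std::optional<std::vector<int64_t>>
  getDistributedShape(std::vector<int64_t> shape,
                      const std::vector<unsigned> &mapPositions,
                      int64_t warpSize) {
    for (unsigned position : mapPositions) {
      if (shape[position] % warpSize != 0) {
        // Dimension smaller than the remaining warp size: give each lane one
        // element along it and distribute the remaining lanes elsewhere.
        if (warpSize % shape[position] != 0)
          return std::nullopt;
        warpSize /= shape[position];
        shape[position] = 1;
        continue;
      }
      shape[position] /= warpSize;
      warpSize = 1;
      break;
    }
    // All lanes must be used, otherwise the distribution is ill-formed.
    if (warpSize != 1)
      return std::nullopt;
    return shape;
  }

  int main() {
    // vector<4x1024xf32> with map (d0, d1) -> (d0, d1) over 32 lanes:
    // dim 0: 4 % 32 != 0 and 32 % 4 == 0, so warpSize -> 8, shape[0] -> 1;
    // dim 1: 1024 % 8 == 0, so shape[1] -> 128. Result: vector<1x128xf32>.
    auto shape = getDistributedShape({4, 1024}, {0, 1}, 32);
    assert(shape && (*shape)[0] == 1 && (*shape)[1] == 128);
    std::printf("vector<%lldx%lldxf32>\n", (long long)(*shape)[0],
                (long long)(*shape)[1]);
  }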
@@ -526,7 +535,30 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
// 4. Reindex the write using the distribution map.
auto newWarpOp =
newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
+
+ // Delinearize the lane id based on the way threads are divided across the
+ // vector. To get the number of threads per vector dimension, divide the
+ // sequential size by the distributed size along each dim.
rewriter.setInsertionPoint(newWriteOp);
+ SmallVector<OpFoldResult> delinearizedIdSizes;
+ for (auto [seqSize, distSize] :
+ llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
+ assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
+ delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
+ }
+ SmallVector<Value> delinearized;
+ if (map.getNumResults() > 1) {
+ delinearized = rewriter
+ .create<mlir::affine::AffineDelinearizeIndexOp>(
+ newWarpOp.getLoc(), newWarpOp.getLaneid(),
+ delinearizedIdSizes)
+ .getResults();
+ } else {
+ // If there is only one map result, we can elide the delinearization
+ // op and use the lane id directly.
+ delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
+ }
+
AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
Location loc = newWriteOp.getLoc();
SmallVector<Value> indices(newWriteOp.getIndices().begin(),
@@ -539,11 +571,11 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
continue;
unsigned indexPos = indexExpr.getPosition();
unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
+ Value laneId = delinearized[vectorPos];
auto scale =
rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
indices[indexPos] = affine::makeComposedAffineApply(
- rewriter, loc, d0 + scale * d1,
- {indices[indexPos], newWarpOp.getLaneid()});
+ rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
}
newWriteOp.getIndicesMutable().assign(indices);
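[Editorial aside: the delinearization above follows the mixed-radix semantics
of affine.delinearize_index, with the basis along each dimension being
seqSize / distSize and the last dimension varying fastest. A standalone C++
sketch of that arithmetic, again not part of the patch and with illustrative
names:]

  #include <cassert>
  #include <cstdint>
  #include <vector>

  // Decompose a linear lane id against a per-dimension basis; the last
  // basis entry varies fastest (row-major order).
  static std::vector<int64_t>
  delinearizeLaneId(int64_t laneId, const std::vector<int64_t> &basis) {
    std::vector<int64_t> ids(basis.size());
    for (int64_t i = static_cast<int64_t>(basis.size()) - 1; i >= 0; --i) {
      ids[i] = laneId % basis[i];
      laneId /= basis[i];
    }
    return ids;
  }

  int main() {
    // For vector<4x1024xf32> distributed to vector<1x128xf32>, the basis is
    // (4 / 1, 1024 / 128) = (4, 8). Lane 13 owns tile (1, 5), i.e. it writes
    // at indices [1, 5 * 128] of the destination memref.
    auto ids = delinearizeLaneId(13, {4, 8});
    assert(ids[0] == 1 && ids[1] == 5);
  }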
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 9072603734879e..bf90c4a6ebb3c2 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1559,3 +1559,28 @@ func.func @warp_propagate_multi_dim_create_mask(%laneid: index, %m0: index, %m1:
// CHECK-PROP: %[[DISTM0:.+]] = affine.apply #[[$SUBM0]]()[%[[M0]], %[[LANEID]]]
// CHECK-PROP: %[[DISTM1:.+]] = affine.apply #[[$SUBM1]]()[%[[M1]], %[[LANEID]]]
// CHECK-PROP: vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[M2]] : vector<1x2x4xi1>
+
+// -----
+
+func.func @warp_propagate_nd_write(%laneid: index, %dest: memref<4x1024xf32>) {
+ %c0 = arith.constant 0 : index
+ vector.warp_execute_on_lane_0(%laneid)[32] -> () {
+ %0 = "some_def"() : () -> (vector<4x1024xf32>)
+ vector.transfer_write %0, %dest[%c0, %c0] : vector<4x1024xf32>, memref<4x1024xf32>
+ vector.yield
+ }
+ return
+}
+
+// CHECK-DIST-AND-PROP: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 128)>
+
+// CHECK-DIST-AND-PROP-LABEL: func.func @warp_propagate_nd_write(
+// CHECK-DIST-AND-PROP: %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1x128xf32>) {
+// CHECK-DIST-AND-PROP: %[[V0:.*]] = "some_def"
+// CHECK-DIST-AND-PROP: vector.yield %[[V0]]
+// CHECK-DIST-AND-PROP-SAME: vector<4x1024xf32>
+// CHECK-DIST-AND-PROP: }
+
+// CHECK-DIST-AND-PROP: %[[IDS:.+]]:2 = affine.delinearize_index %{{.*}} into (%c4, %c8) : index, index
+// CHECK-DIST-AND-PROP: %[[INNER_ID:.+]] = affine.apply #map()[%[[IDS]]#1]
+// CHECK-DIST-AND-PROP: vector.transfer_write %[[W]], %{{.*}}[%[[IDS]]#0, %[[INNER_ID]]] {{.*}} : vector<1x128xf32>
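[Editorial aside: tracing the test above, the 4x1024 write is distributed
over 32 lanes into 1x128 tiles, so the delinearization basis is
(4/1, 1024/128) = (4, 8); the outer write index is the first delinearized id,
and the inner index is the second id scaled by the tile size 128, matching
the (s0 * 128) affine map checked above.]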
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 178a58e796b246..915f713f7047be 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -630,15 +630,13 @@ struct TestVectorDistribution
});
MLIRContext *ctx = &getContext();
auto distributionFn = [](Value val) {
- // Create a map (d0, d1) -> (d1) to distribute along the inner
- // dimension. Once we support n-d distribution we can add more
- // complex cases.
+ // Create an identity dim map of the same rank as the vector.
VectorType vecType = dyn_cast<VectorType>(val.getType());
int64_t vecRank = vecType ? vecType.getRank() : 0;
OpBuilder builder(val.getContext());
if (vecRank == 0)
return AffineMap::get(val.getContext());
- return AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
+ return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
};
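[Editorial aside: for a rank-2 vector, AffineMap::getMultiDimIdentityMap
produces (d0, d1) -> (d0, d1), so the test driver now distributes across all
vector dimensions rather than only the innermost one as the previous
(d0, d1) -> (d1) map did.]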
auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
Value srcIdx, int64_t warpSz) {