[Mlir-commits] [mlir] [mlir][vector] Add support for multi-dim reduction vector distribution (PR #71193)
Kunwar Grover
llvmlistbot at llvm.org
Thu Nov 9 06:27:39 PST 2023
https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/71193
From c6319ed1f22676d717afaa5ad5b8875904313528 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Fri, 3 Nov 2023 21:12:01 +0530
Subject: [PATCH 1/7] [mlir][vector] Add support for multi-dim reduction vector
distribution
---
.../Vector/Transforms/VectorDistribute.cpp | 49 +++++++++++++++----
1 file changed, 39 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e128cc71a5d628c..4f3d93abdf83c69 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -435,23 +435,48 @@ static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
/// Return the distributed vector type based on the original type and the
/// distribution map. The map is expected to have a dimension equal to the
/// original type rank and should be a projection where the results are the
-/// distributed dimensions. The number of results should be equal to the number
-/// of warp sizes which is currently limited to 1.
-/// Example: For a vector<16x32x64> distributed with a map(d0, d1, d2) -> (d1)
-/// and a warp size of 16 would distribute the second dimension (associated to
-/// d1) and return vector<16x2x64>
+/// distributed dimensions. The vector should be completely distributable, i.e.
+/// the linearized shape should be a multiple of the warp size.
+/// Example (single-dim): For a vector<16x32x64> distributed with
+/// a map(d0, d1, d2) -> (d1) and a warp size of 16 would distribute the second
+/// dimension (associated to d1) and return vector<16x2x64>.
+/// Example (multi-dim): For a vector<16x32x64> distributed with a
+/// map(d0, d1, d2) -> (d1, d2), and a warp size of 128 would distribute first
+/// the second dimension and then the third dimension, finally returning a
+/// vector<16x1x16>.
static VectorType getDistributedType(VectorType originalType, AffineMap map,
int64_t warpSize) {
- if (map.getNumResults() != 1)
- return VectorType();
+ assert(map.isProjectedPermutation() && "expected projected permutation map");
+
SmallVector<int64_t> targetShape(originalType.getShape().begin(),
originalType.getShape().end());
+ // Distribute the vector based on the order of dimensions in the affine map.
+ int64_t availableThreads = warpSize;
for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
unsigned position = map.getDimPosition(i);
- if (targetShape[position] % warpSize != 0)
- return VectorType();
- targetShape[position] = targetShape[position] / warpSize;
+ int64_t &dimSize = targetShape[position];
+ if (availableThreads > dimSize) {
+ // We have more threads available than the size of the dimension, so we
+ // distribute the whole dimension.
+ if (availableThreads % dimSize != 0)
+ return VectorType();
+ availableThreads = availableThreads / dimSize;
+ dimSize = 1;
+ } else {
+      // The dimension is bigger than the number of threads available,
+ // so we distribute a part of the dimension to each thread.
+ if (dimSize % availableThreads != 0)
+ return VectorType();
+ dimSize = dimSize / availableThreads;
+ availableThreads = 1;
+ break;
+ }
}
+
+ // If we could not distribute the whole vector, we fail.
+ if (availableThreads != 1)
+ return VectorType();
+
VectorType targetType =
VectorType::get(targetShape, originalType.getElementType());
return targetType;
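For readers following the new loop above: stripped of MLIR types, the shape
computation is ordinary integer arithmetic. The sketch below is standalone C++
with illustrative names (`distributeShape`, `order` — neither exists in the
patch), where `order` plays the role of the affine map results:

#include <cstdint>
#include <optional>
#include <vector>

// Greedily hand out warp lanes to the dimensions listed in `order`.
std::optional<std::vector<int64_t>>
distributeShape(std::vector<int64_t> shape, const std::vector<unsigned> &order,
                int64_t warpSize) {
  int64_t availableThreads = warpSize;
  for (unsigned pos : order) {
    int64_t &dimSize = shape[pos];
    if (availableThreads > dimSize) {
      // More lanes than elements: the dimension is fully distributed and
      // every lane keeps a single element along it.
      if (availableThreads % dimSize != 0)
        return std::nullopt;
      availableThreads /= dimSize;
      dimSize = 1;
    } else {
      // The dimension absorbs all remaining lanes: every lane keeps a
      // contiguous slice of it.
      if (dimSize % availableThreads != 0)
        return std::nullopt;
      dimSize /= availableThreads;
      availableThreads = 1;
      break;
    }
  }
  // The vector is distributable only if every lane was consumed.
  if (availableThreads != 1)
    return std::nullopt;
  return shape;
}

// distributeShape({16, 32, 64}, {1}, 16)     yields {16, 2, 64}
// distributeShape({16, 32, 64}, {1, 2}, 128) yields {16, 1, 16}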
@@ -1512,6 +1537,10 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
});
+ // Check if any types could not be distributed.
+ if (llvm::any_of(distTypes, [](Type t) { return !t; }))
+ return failure();
+
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
From 8487e5719baaac96e541f5aea63bc189b5ade263 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Sat, 4 Nov 2023 12:28:01 +0530
Subject: [PATCH 2/7] Add tests
---
.../Vector/vector-warp-distribute.mlir | 56 +++++++++++++++++++
.../Dialect/Vector/TestVectorTransforms.cpp | 14 +++--
2 files changed, 64 insertions(+), 6 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index f050bcd246e5ef7..29ac67ac9a93371 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -496,6 +496,62 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2
// -----
+// CHECK-PROP-LABEL: func @warp_scf_for_multi_reduce(
+// CHECK-PROP-NOT: vector.warp_execute_on_lane_0
+// CHECK-PROP: scf.for {{.*}} -> (vector<1x4xf32>) {
+// CHECK-PROP: scf.for {{.*}} -> (vector<1x4xf32>) {
+// CHECK-PROP: vector.transfer_read {{.*}} : memref<2x32x40x384xf32>, vector<1x4xf32>
+// CHECK-PROP: }
+// CHECK-PROP: }
+// CHECK-PROP: vector.reduction <add>
+// CHECK-PROP: gpu.shuffle
+#map = affine_map<(d0, d1) -> (0, 0)>
+func.func @warp_scf_for_multi_reduce(%arg0: memref<2x32x40x384xf32>, %arg1: memref<2x32x40x384xf16>, %arg2: memref<2x32xf32>, %arg3: memref<2x32x40x384xf16>) {
+ %cst = arith.constant dense<1.536000e+04> : vector<8x128xf32>
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<8x128xf32>
+ %cst_1 = arith.constant 9.99999997E-7 : f32
+ %c128 = arith.constant 128 : index
+ %c8 = arith.constant 8 : index
+ %c0 = arith.constant 0 : index
+ %c40 = arith.constant 40 : index
+ %c384 = arith.constant 384 : index
+ %cst_2 = arith.constant 0.000000e+00 : f16
+ %cst_3 = arith.constant 0.000000e+00 : f32
+ %0 = gpu.thread_id x
+ %1 = arith.truncf %cst_1 : f32 to f16
+ vector.warp_execute_on_lane_0(%0)[256] {
+ %2 = scf.for %arg4 = %c0 to %c40 step %c8 iter_args(%arg5 = %cst_0) -> (vector<8x128xf32>) {
+ %11 = scf.for %arg6 = %c0 to %c384 step %c128 iter_args(%arg7 = %arg5) -> (vector<8x128xf32>) {
+ %12 = vector.transfer_read %arg0[%c0, %c0, %arg4, %arg6], %cst_3 {in_bounds = [true, true]} : memref<2x32x40x384xf32>, vector<8x128xf32>
+ %13 = arith.addf %12, %arg7 : vector<8x128xf32>
+ scf.yield %13 : vector<8x128xf32>
+ }
+ scf.yield %11 : vector<8x128xf32>
+ }
+ %3 = vector.shape_cast %2 : vector<8x128xf32> to vector<1024xf32>
+ %4 = vector.reduction <add>, %3, %cst_3 : vector<1024xf32> into f32
+ %5 = vector.broadcast %4 : f32 to vector<8x128xf32>
+ %6 = arith.divf %5, %cst : vector<8x128xf32>
+ %7 = arith.truncf %6 : vector<8x128xf32> to vector<8x128xf16>
+ %8 = vector.broadcast %1 : f16 to vector<8x128xf16>
+ %9 = arith.addf %7, %8 : vector<8x128xf16>
+ %10 = math.rsqrt %9 : vector<8x128xf16>
+ scf.for %arg4 = %c0 to %c40 step %c8 {
+ %11 = vector.transfer_read %arg2[%c0, %c0], %cst_3 {in_bounds = [true, true], permutation_map = #map} : memref<2x32xf32>, vector<8x128xf32>
+ %12 = arith.truncf %11 : vector<8x128xf32> to vector<8x128xf16>
+ scf.for %arg5 = %c0 to %c384 step %c128 {
+ %13 = vector.transfer_read %arg1[%c0, %c0, %arg4, %arg5], %cst_2 {in_bounds = [true, true]} : memref<2x32x40x384xf16>, vector<8x128xf16>
+ %14 = arith.subf %13, %12 : vector<8x128xf16>
+ %15 = arith.mulf %14, %10 : vector<8x128xf16>
+ vector.transfer_write %15, %arg3[%c0, %c0, %arg4, %arg5] {in_bounds = [true, true]} : vector<8x128xf16>, memref<2x32x40x384xf16>
+ }
+ }
+ }
+ return
+}
+
+// -----
+
// CHECK-PROP-LABEL: func @vector_reduction(
// CHECK-PROP-SAME: %[[laneid:.*]]: index)
// CHECK-PROP-DAG: %[[c1:.*]] = arith.constant 1 : i32
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 2fbf1babf437f08..cf7799a403a3295 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -569,15 +569,17 @@ struct TestVectorDistribution
});
MLIRContext *ctx = &getContext();
auto distributionFn = [](Value val) {
- // Create a map (d0, d1) -> (d1) to distribute along the inner
- // dimension. Once we support n-d distribution we can add more
- // complex cases.
+ // Create a map (d0, d1) -> (d1, d0) to distribute starting from the inner
+ // dimensions.
VectorType vecType = dyn_cast<VectorType>(val.getType());
int64_t vecRank = vecType ? vecType.getRank() : 0;
OpBuilder builder(val.getContext());
- if (vecRank == 0)
- return AffineMap::get(val.getContext());
- return AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
+ SmallVector<AffineExpr, 4> vecDims = llvm::to_vector<4>(
+ llvm::map_range(llvm::seq<int64_t>(0, vecRank), [&](int64_t i) {
+ return builder.getAffineDimExpr(vecRank - i - 1);
+ }));
+ return AffineMap::get(vecRank, /*symbolCount=*/0, vecDims,
+ builder.getContext());
};
auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
Value srcIdx, int64_t warpSz) {
From 8775f1c92ee64022fb3db54604229de804810a7f Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Wed, 8 Nov 2023 11:13:14 +0530
Subject: [PATCH 3/7] Address comments
---
.../Vector/Transforms/VectorDistribution.h | 10 ++++++
.../Vector/Transforms/VectorDistribute.cpp | 31 +++++++++++++------
.../Vector/vector-warp-distribute.mlir | 11 ++++++-
.../Dialect/Vector/TestVectorTransforms.cpp | 6 ++--
4 files changed, 44 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index a76a58eb5ec6d3c..781a46117602aae 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -40,6 +40,16 @@ void populateWarpExecuteOnLane0OpToScfForPattern(
const WarpExecuteOnLane0LoweringOptions &options,
PatternBenefit benefit = 1);
+/// Given a value having a shaped type, returns the distribution map for that
+/// value. The distribution map represents the order in which the dimensions
+/// of the shape should be distributed. The map is expected to be a projected
+/// permutation of the shape dimensions. Examples of distribution maps that
+/// can be returned:
+///
+/// - Type: vector<16x32x64xf32>,
+/// Map: (d0, d1, d2) -> (d1, d2) : Distribute d1, and then d2
+/// - Type: vector<16x32x64xf32>
+/// Map: (d0, d1, d2) -> (d2, d1, d0) : Distribute d2, then d1 and then d0
using DistributionMapFn = std::function<AffineMap(Value)>;
/// Distribute transfer_write ops based on the affine map returned by
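As an illustration of this contract, a DistributionMapFn requesting
innermost-first distribution could be written as below. This is only a sketch
(the function name is made up, and the usual MLIR includes are assumed); it is
a condensed variant of the test-pass callback updated later in this series:

static AffineMap getInnermostFirstMap(Value val) {
  auto vecType = dyn_cast<VectorType>(val.getType());
  int64_t rank = vecType ? vecType.getRank() : 0;
  MLIRContext *ctx = val.getContext();
  // Request distribution of d(rank-1) first, then d(rank-2), and so on.
  SmallVector<AffineExpr> results;
  for (int64_t i = rank - 1; i >= 0; --i)
    results.push_back(getAffineDimExpr(i, ctx));
  return AffineMap::get(rank, /*symbolCount=*/0, results, ctx);
}

For a rank-2 vector this returns (d0, d1) -> (d1, d0), i.e. d1 is handed to
the warp lanes before d0.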
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 4f3d93abdf83c69..47aafd8302f3108 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -433,20 +433,31 @@ static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
}
/// Return the distributed vector type based on the original type and the
-/// distribution map. The map is expected to have a dimension equal to the
-/// original type rank and should be a projection where the results are the
-/// distributed dimensions. The vector should be completely distributable, i.e.
+/// distribution map. The vector should be completely distributable, i.e.
/// the linearized shape should be a multiple of the warp size.
-/// Example (single-dim): For a vector<16x32x64> distributed with
-/// a map(d0, d1, d2) -> (d1) and a warp size of 16 would distribute the second
-/// dimension (associated to d1) and return vector<16x2x64>.
-/// Example (multi-dim): For a vector<16x32x64> distributed with a
+///
+/// The distribution map specifies the order in which the vector dimensions
+/// should be distributed. The map is expected to be a projected permutation of
+/// the vector shape dimensions. Examples of distribution maps:
+/// - (d0, d1, d2) -> (d1, d2) : Distribute d1, and then d2
+/// - (d0, d1, d2) -> (d2, d1, d0) : Distribute d2, then d1 and then d0
+/// If all threads are used while distributing the first few dimensions, the
+/// remaining dimensions may not be used for distribution.
+///
+/// Example (single-dim): For a vector<16x32x64> distributed with a
+/// map(d0, d1, d2) -> (d1) and a warp size of 16 would distribute the second
+/// dimension (associated to d1) and return vector<16x2x64>.
+///
+/// Example (multi-dim): For a vector<16x32x64> distributed with a
/// map(d0, d1, d2) -> (d1, d2), and a warp size of 128 would distribute first
/// the second dimension and then the third dimension, finally returning a
/// vector<16x1x16>.
static VectorType getDistributedType(VectorType originalType, AffineMap map,
int64_t warpSize) {
- assert(map.isProjectedPermutation() && "expected projected permutation map");
+ if (!map.isProjectedPermutation()) {
+ assert(false && "expected projected permutation map");
+ return VectorType();
+ }
SmallVector<int64_t> targetShape(originalType.getShape().begin(),
originalType.getShape().end());
@@ -457,14 +468,14 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
int64_t &dimSize = targetShape[position];
if (availableThreads > dimSize) {
// We have more threads available than the size of the dimension, so we
- // distribute the whole dimension.
+      // distribute with size 1 along this dimension.
if (availableThreads % dimSize != 0)
return VectorType();
availableThreads = availableThreads / dimSize;
dimSize = 1;
} else {
      // The dimension is bigger than the number of threads available,
- // so we distribute a part of the dimension to each thread.
+ // so we distribute with size > 1 along this dimension.
if (dimSize % availableThreads != 0)
return VectorType();
dimSize = dimSize / availableThreads;
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 29ac67ac9a93371..621792efd467bd9 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -504,7 +504,16 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2
// CHECK-PROP: }
// CHECK-PROP: }
// CHECK-PROP: vector.reduction <add>
-// CHECK-PROP: gpu.shuffle
+// CHECK-PROP-COUNT-8: gpu.shuffle
+//
+// CHECK-PROP: scf.for {{.*}} {
+// CHECK-PROP: vector.transfer_read
+// CHECK-PROP: scf.for {{.*}} {
+// CHECK-PROP: vector.warp_execute_on_lane_0
+// CHECK-PROP: vector.transfer_read
+// CHECK-PROP: vector.transfer_write
+// CHECK-PROP: }
+// CHECK-PROP: }
#map = affine_map<(d0, d1) -> (0, 0)>
func.func @warp_scf_for_multi_reduce(%arg0: memref<2x32x40x384xf32>, %arg1: memref<2x32x40x384xf16>, %arg2: memref<2x32xf32>, %arg3: memref<2x32x40x384xf16>) {
%cst = arith.constant dense<1.536000e+04> : vector<8x128xf32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index cf7799a403a3295..ed981fcdfe60d87 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -574,10 +574,10 @@ struct TestVectorDistribution
VectorType vecType = dyn_cast<VectorType>(val.getType());
int64_t vecRank = vecType ? vecType.getRank() : 0;
OpBuilder builder(val.getContext());
- SmallVector<AffineExpr, 4> vecDims = llvm::to_vector<4>(
- llvm::map_range(llvm::seq<int64_t>(0, vecRank), [&](int64_t i) {
+ SmallVector<AffineExpr, 4> vecDims =
+ llvm::map_to_vector(llvm::seq<int64_t>(0, vecRank), [&](int64_t i) {
return builder.getAffineDimExpr(vecRank - i - 1);
- }));
+ });
return AffineMap::get(vecRank, /*symbolCount=*/0, vecDims,
builder.getContext());
};
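A note on the projected-permutation precondition that getDistributedType now
guards: in plain-container terms it only requires that every entry of the
distribution order be a distinct dimension index within the vector rank. A
standalone C++ sketch with an illustrative name:

#include <unordered_set>
#include <vector>

// True iff `order` is a projected permutation of the dimensions 0..rank-1.
bool isValidDistributionOrder(const std::vector<unsigned> &order,
                              unsigned rank) {
  std::unordered_set<unsigned> seen;
  for (unsigned pos : order)
    if (pos >= rank || !seen.insert(pos).second)
      return false;
  return true;
}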
From 76f9d11a646ca0b792f18e1767177606255575ab Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Wed, 8 Nov 2023 15:21:11 +0530
Subject: [PATCH 4/7] Update test
---
.../Vector/vector-warp-distribute.mlir | 34 +++++++++----------
1 file changed, 17 insertions(+), 17 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 621792efd467bd9..c05e9cafd04323d 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -496,24 +496,24 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2
// -----
-// CHECK-PROP-LABEL: func @warp_scf_for_multi_reduce(
-// CHECK-PROP-NOT: vector.warp_execute_on_lane_0
-// CHECK-PROP: scf.for {{.*}} -> (vector<1x4xf32>) {
-// CHECK-PROP: scf.for {{.*}} -> (vector<1x4xf32>) {
-// CHECK-PROP: vector.transfer_read {{.*}} : memref<2x32x40x384xf32>, vector<1x4xf32>
-// CHECK-PROP: }
-// CHECK-PROP: }
-// CHECK-PROP: vector.reduction <add>
-// CHECK-PROP-COUNT-8: gpu.shuffle
+// CHECK-PROP-LABEL: func @warp_scf_for_multi_reduce(
+// CHECK-PROP-NOT: vector.warp_execute_on_lane_0
+// CHECK-PROP: scf.for {{.*}} -> (vector<1x4xf32>) {
+// CHECK-PROP: scf.for {{.*}} -> (vector<1x4xf32>) {
+// CHECK-PROP: vector.transfer_read {{.*}} : memref<2x32x40x384xf32>, vector<1x4xf32>
+// CHECK-PROP: }
+// CHECK-PROP: }
+// CHECK-PROP: vector.reduction <add>
+// CHECK-PROP-COUNT-8: gpu.shuffle
//
-// CHECK-PROP: scf.for {{.*}} {
-// CHECK-PROP: vector.transfer_read
-// CHECK-PROP: scf.for {{.*}} {
-// CHECK-PROP: vector.warp_execute_on_lane_0
-// CHECK-PROP: vector.transfer_read
-// CHECK-PROP: vector.transfer_write
-// CHECK-PROP: }
-// CHECK-PROP: }
+// CHECK-PROP: scf.for {{.*}} {
+// CHECK-PROP: vector.transfer_read
+// CHECK-PROP: scf.for {{.*}} {
+// CHECK-PROP: vector.warp_execute_on_lane_0
+// CHECK-PROP: vector.transfer_read
+// CHECK-PROP: vector.transfer_write
+// CHECK-PROP: }
+// CHECK-PROP: }
#map = affine_map<(d0, d1) -> (0, 0)>
func.func @warp_scf_for_multi_reduce(%arg0: memref<2x32x40x384xf32>, %arg1: memref<2x32x40x384xf16>, %arg2: memref<2x32xf32>, %arg3: memref<2x32x40x384xf16>) {
%cst = arith.constant dense<1.536000e+04> : vector<8x128xf32>
From 864297ab32af54e535e2698c1f08d875ca2b5af1 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Thu, 9 Nov 2023 14:35:20 +0530
Subject: [PATCH 5/7] WIP
---
.../Vector/Transforms/VectorDistribution.h | 17 ++-------------
.../Vector/Transforms/VectorDistribute.cpp | 21 +++++++------------
2 files changed, 10 insertions(+), 28 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index 781a46117602aae..5503ca8ec455a91 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -40,18 +40,6 @@ void populateWarpExecuteOnLane0OpToScfForPattern(
const WarpExecuteOnLane0LoweringOptions &options,
PatternBenefit benefit = 1);
-/// Given a value having a shaped type, returns the distribution map for that
-/// value. The distribution map represents the order in which the dimensions
-/// of the shape should be distributed. The map is expected to be a projected
-/// permutation of the shape dimensions. Examples of distribution maps that
-/// can be returned:
-///
-/// - Type: vector<16x32x64xf32>,
-/// Map: (d0, d1, d2) -> (d1, d2) : Distribute d1, and then d2
-/// - Type: vector<16x32x64xf32>
-/// Map: (d0, d1, d2) -> (d2, d1, d0) : Distribute d2, then d1 and then d0
-using DistributionMapFn = std::function<AffineMap(Value)>;
-
/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`.
/// Example:
@@ -69,9 +57,8 @@ using DistributionMapFn = std::function<AffineMap(Value)>;
/// vector.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
-void populateDistributeTransferWriteOpPatterns(
- RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
- PatternBenefit benefit = 1);
+void populateDistributeTransferWriteOpPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
/// Move scalar operations with no dependency on the warp op outside of the
/// region.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 47aafd8302f3108..c995c6633418b06 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -433,22 +433,16 @@ static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
}
/// Return the distributed vector type based on the original type and the
-/// distribution map. The vector should be completely distributable, i.e.
-/// the linearized shape should be a multiple of the warp size.
+/// distribution map. The vector should be completely distributable, i.e. the
+/// linearized shape should be a multiple of the warp size. If all threads are
+/// used while distributing the first few dimensions, the remaining dimensions
+/// may not be used for distribution.
///
-/// The distribution map specifies the order in which the vector dimensions
-/// should be distributed. The map is expected to be a projected permutation of
-/// the vector shape dimensions. Examples of distribution maps:
-/// - (d0, d1, d2) -> (d1, d2) : Distribute d1, and then d2
-/// - (d0, d1, d2) -> (d2, d1, d0) : Distribute d2, then d1 and then d0
-/// If all threads are used while distributing the first few dimensions, the
-/// remaining dimensions may not be used for distribution.
-///
-/// Example (single-dim): For a vector<16x32x64> distributed with a
+/// Example (single-dim): For a vector<16x32x64> distributed with a
/// map(d0, d1, d2) -> (d1) and a warp size of 16 would distribute the second
-/// dimension (associated to d1) and return vector<16x2x64>.
+/// dimension (associated to d1) and return vector<16x2x64>.
///
-/// Example (multi-dim): For a vector<16x32x64> distributed with a
+/// Example (multi-dim): For a vector<16x32x64> distributed with a
/// map(d0, d1, d2) -> (d1, d2), and a warp size of 128 would distribute first
/// the second dimension and then the third dimension, finally returning a
/// vector<16x1x16>.
@@ -773,6 +767,7 @@ bool delinearizeLaneId(OpBuilder &builder, Location loc,
ArrayRef<int64_t> originalShape,
ArrayRef<int64_t> distributedShape, int64_t warpSize,
Value laneId, SmallVectorImpl<Value> &delinearizedIds) {
+
// If the original shape and the distributed shape is the same, we don't
// distribute at all--every thread is handling the whole. For such case, we
// should not rely on lane IDs later. So just return an empty lane ID vector.
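Some context on delinearizeLaneId, which the hunk above touches: ignoring the
affine machinery, it performs mixed-radix delinearization of the flat lane ID
over the number of lanes assigned to each distributed dimension. A
plain-integer sketch, where the innermost-first traversal order is an
assumption made for illustration:

#include <cstdint>
#include <vector>

// `lanesPerDim` holds, innermost dimension first, how many lanes were
// assigned to each distributed dimension; their product is the warp size.
std::vector<int64_t> delinearize(int64_t laneId,
                                 const std::vector<int64_t> &lanesPerDim) {
  std::vector<int64_t> ids;
  for (int64_t lanes : lanesPerDim) {
    ids.push_back(laneId % lanes); // this lane's position along the dim
    laneId /= lanes;               // the rest indexes the outer dims
  }
  return ids;
}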
From a702da3d4f06d857493a748ac48f660bbd052ffb Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Thu, 9 Nov 2023 19:10:41 +0530
Subject: [PATCH 6/7] implicit distribution map
---
.../Vector/Transforms/VectorDistribution.h | 6 +-
.../Vector/Transforms/VectorDistribute.cpp | 136 +++++++-----------
.../Vector/vector-warp-distribute.mlir | 46 ++++++
.../Dialect/Vector/TestVectorTransforms.cpp | 19 +--
4 files changed, 100 insertions(+), 107 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index 5503ca8ec455a91..fbb59abfce5e49f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -40,8 +40,8 @@ void populateWarpExecuteOnLane0OpToScfForPattern(
const WarpExecuteOnLane0LoweringOptions &options,
PatternBenefit benefit = 1);
-/// Distribute transfer_write ops based on the affine map returned by
-/// `distributionMapFn`.
+/// Distribute transfer_write ops based.
+/// TODO: Add documentation here on how the distribution is done.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id){
@@ -73,7 +73,7 @@ using WarpShuffleFromIdxFn =
/// to decide how a value should be distributed when this cannot be inferred
/// from its uses.
void populatePropagateWarpVectorDistributionPatterns(
- RewritePatternSet &pattern, const DistributionMapFn &distributionMapFn,
+ RewritePatternSet &pattern,
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
PatternBenefit benefit = 1);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index c995c6633418b06..09ed1111df72d39 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -22,33 +22,6 @@
using namespace mlir;
using namespace mlir::vector;
-/// Currently the distribution map is implicit based on the vector shape. In the
-/// future it will be part of the op.
-/// Example:
-/// ```
-/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
-/// ...
-/// vector.yield %3 : vector<32x16x64xf32>
-/// }
-/// ```
-/// Would have an implicit map of:
-/// `(d0, d1, d2) -> (d0, d2)`
-static AffineMap calculateImplicitMap(VectorType sequentialType,
- VectorType distributedType) {
- SmallVector<AffineExpr> perm;
- perm.reserve(1);
- // Check which dimensions of the sequential type are different than the
- // dimensions of the distributed type to know the distributed dimensions. Then
- // associate each distributed dimension to an ID in order.
- for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
- if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
- perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
- }
- auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
- distributedType.getContext());
- return map;
-}
-
namespace {
/// Helper struct to create the load / store operations that permit transit
@@ -63,9 +36,6 @@ struct DistributedLoadStoreHelper {
laneId(laneId), zero(zero) {
sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
- if (sequentialVectorType && distributedVectorType)
- distributionMap =
- calculateImplicitMap(sequentialVectorType, distributedVectorType);
}
Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
@@ -96,10 +66,8 @@ struct DistributedLoadStoreHelper {
int64_t rank = sequentialVectorType.getRank();
SmallVector<Value> indices(rank, zero);
if (val == distributedVal) {
- for (auto dimExpr : distributionMap.getResults()) {
- int64_t index = dimExpr.cast<AffineDimExpr>().getPosition();
+ for (auto index : llvm::seq<int64_t>(0, rank))
indices[index] = buildDistributedOffset(b, loc, index);
- }
}
SmallVector<bool> inBounds(indices.size(), true);
return b.create<vector::TransferWriteOp>(
@@ -139,12 +107,11 @@ struct DistributedLoadStoreHelper {
assert((type == distributedVectorType || type == sequentialVectorType) &&
"Must store either the preregistered distributed or the "
"preregistered sequential type.");
- SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
+ int64_t rank = sequentialVectorType.getRank();
+ SmallVector<Value> indices(rank, zero);
if (type == distributedVectorType) {
- for (auto dimExpr : distributionMap.getResults()) {
- int64_t index = dimExpr.cast<AffineDimExpr>().getPosition();
+ for (auto index : llvm::seq<int64_t>(0, rank))
indices[index] = buildDistributedOffset(b, loc, index);
- }
}
SmallVector<bool> inBounds(indices.size(), true);
return b.create<vector::TransferReadOp>(
@@ -154,7 +121,6 @@ struct DistributedLoadStoreHelper {
Value sequentialVal, distributedVal, laneId, zero;
VectorType sequentialVectorType, distributedVectorType;
- AffineMap distributionMap;
};
} // namespace
@@ -446,20 +412,13 @@ static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
/// map(d0, d1, d2) -> (d1, d2), and a warp size of 128 would distribute first
/// the second dimension and then the third dimension, finally returning a
/// vector<16x1x16>.
-static VectorType getDistributedType(VectorType originalType, AffineMap map,
+static VectorType getDistributedType(VectorType originalType,
int64_t warpSize) {
- if (!map.isProjectedPermutation()) {
- assert(false && "expected projected permutation map");
- return VectorType();
- }
-
SmallVector<int64_t> targetShape(originalType.getShape().begin(),
originalType.getShape().end());
// Distribute the vector based on the order of dimensions in the affine map.
int64_t availableThreads = warpSize;
- for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
- unsigned position = map.getDimPosition(i);
- int64_t &dimSize = targetShape[position];
+ for (int64_t &dimSize : targetShape) {
if (availableThreads > dimSize) {
// We have more threads available than the size of the dimension, so we
      // distribute with size 1 along this dimension.
@@ -505,16 +464,22 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
/// }
/// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
- WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
- PatternBenefit b = 1)
- : OpRewritePattern<vector::TransferWriteOp>(ctx, b),
- distributionMapFn(std::move(fn)) {}
+ WarpOpTransferWrite(MLIRContext *ctx, PatternBenefit b = 1)
+ : OpRewritePattern<vector::TransferWriteOp>(ctx, b) {}
/// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
/// are multiples of the distribution ratio are supported at the moment.
LogicalResult tryDistributeOp(RewriterBase &rewriter,
vector::TransferWriteOp writeOp,
WarpExecuteOnLane0Op warpOp) const {
+
+ // TODO: Distribution of non-trivial projected permutation maps
+    // breaks the implicit assumption that we distribute the outermost
+ // vector dimensions. To allow for this, we need a global layout analysis
+ // over warp_execute_on_lane_0 instead of the current local analysis.
+ if (!writeOp.getPermutationMap().isMinorIdentityWithBroadcasting())
+ return failure();
+
VectorType writtenVectorType = writeOp.getVectorType();
// 1. If the write is 0-D, we just clone it into a new WarpExecuteOnLane0Op
@@ -523,9 +488,8 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
return failure();
// 2. Compute the distributed type.
- AffineMap map = distributionMapFn(writeOp.getVector());
VectorType targetType =
- getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
+ getDistributedType(writtenVectorType, warpOp.getWarpSize());
if (!targetType)
return failure();
@@ -541,7 +505,7 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
if (!writeOp.getPermutationMap().isMinorIdentity())
return failure();
maskType =
- getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
+ getDistributedType(writeOp.getMaskType(), warpOp.getWarpSize());
}
// 3. clone the write into a new WarpExecuteOnLane0Op to separate it from
@@ -553,20 +517,19 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
auto newWarpOp =
newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
rewriter.setInsertionPoint(newWriteOp);
- AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
+ AffineMap indexMap = newWriteOp.getPermutationMap();
Location loc = newWriteOp.getLoc();
SmallVector<Value> indices(newWriteOp.getIndices().begin(),
newWriteOp.getIndices().end());
- for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
+ for (AffineExpr it : indexMap.getResults()) {
AffineExpr d0, d1;
bindDims(newWarpOp.getContext(), d0, d1);
- auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
+ auto indexExpr = it.dyn_cast<AffineDimExpr>();
if (!indexExpr)
continue;
unsigned indexPos = indexExpr.getPosition();
- unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
auto scale =
- rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
+ rewriter.getAffineConstantExpr(targetType.getDimSize(indexPos));
indices[indexPos] = affine::makeComposedAffineApply(
rewriter, loc, d0 + scale * d1,
{indices[indexPos], newWarpOp.getLaneid()});
@@ -646,9 +609,6 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
return failure();
}
-
-private:
- DistributionMapFn distributionMapFn;
};
/// Sink out elementwise op feeding into a warp op yield.
@@ -837,18 +797,23 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
if (!operand)
return failure();
auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
+
+ // TODO: Distribution of non-trivial projected permutation maps
+    // breaks the implicit assumption that we distribute the outermost
+ // vector dimensions. To allow for this, we need a global layout analysis
+ // over warp_execute_on_lane_0 instead of the current local analysis.
+ if (!read.getPermutationMap().isMinorIdentityWithBroadcasting())
+ return failure();
+
// Don't duplicate transfer_read ops when distributing.
if (!read.getResult().hasOneUse())
return failure();
unsigned operandIndex = operand->getOperandNumber();
Value distributedVal = warpOp.getResult(operandIndex);
- SmallVector<Value, 4> indices(read.getIndices().begin(),
- read.getIndices().end());
auto sequentialType = cast<VectorType>(read.getResult().getType());
auto distributedType = cast<VectorType>(distributedVal.getType());
- AffineMap map = calculateImplicitMap(sequentialType, distributedType);
- AffineMap indexMap = map.compose(read.getPermutationMap());
+ AffineMap indexMap = read.getPermutationMap();
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(warpOp);
@@ -860,20 +825,21 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
warpOp.getLaneid(), delinearizedIds))
return rewriter.notifyMatchFailure(
read, "cannot delinearize lane ID for distribution");
- assert(!delinearizedIds.empty() || map.getNumResults() == 0);
+ assert(!delinearizedIds.empty());
- for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
+ SmallVector<Value, 4> indices(read.getIndices().begin(),
+ read.getIndices().end());
+ for (AffineExpr it : indexMap.getResults()) {
AffineExpr d0, d1;
bindDims(read.getContext(), d0, d1);
- auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
+ auto indexExpr = it.dyn_cast<AffineDimExpr>();
if (!indexExpr)
continue;
unsigned indexPos = indexExpr.getPosition();
- unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
- int64_t scale = distributedType.getDimSize(vectorPos);
+ auto scale = distributedType.getDimSize(indexPos);
indices[indexPos] = affine::makeComposedAffineApply(
rewriter, read.getLoc(), d0 + scale * d1,
- {indices[indexPos], delinearizedIds[vectorPos]});
+ {indices[indexPos], delinearizedIds[indexPos]});
}
auto newRead = rewriter.create<vector::TransferReadOp>(
read.getLoc(), distributedVal.getType(), read.getSource(), indices,
@@ -1055,7 +1021,7 @@ struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
VectorType castResultType = castDistributedType;
// We expect the distributed type to have a smaller rank than the original
- // type. Prepend with size-one dimensions to make them the same.
+  // type. We distribute the original type using the standard distribution scheme.
unsigned castDistributedRank = castDistributedType.getRank();
unsigned castOriginalRank = castOriginalType.getRank();
if (castDistributedRank < castOriginalRank) {
@@ -1064,6 +1030,8 @@ struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
castDistributedType =
VectorType::get(shape, castDistributedType.getElementType());
}
+ castDistributedType =
+ getDistributedType(castOriginalType, warpOp.getWarpSize());
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
@@ -1508,9 +1476,8 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
/// ```
struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
- WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
- : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
- distributionMapFn(std::move(fn)) {}
+ WarpOpScfForOp(MLIRContext *ctx, PatternBenefit b = 1)
+ : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b) {}
using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
@@ -1535,8 +1502,7 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
return;
Type distType = operand->get().getType();
if (auto vecType = dyn_cast<VectorType>(distType)) {
- AffineMap map = distributionMapFn(operand->get());
- distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+ distType = getDistributedType(vecType, warpOp.getWarpSize());
}
inputTypes.push_back(operand->get().getType());
distTypes.push_back(distType);
@@ -1627,9 +1593,6 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
mlir::vector::moveScalarUniformCode(innerWarp);
return success();
}
-
-private:
- DistributionMapFn distributionMapFn;
};
/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
@@ -1724,14 +1687,12 @@ void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
}
void mlir::vector::populateDistributeTransferWriteOpPatterns(
- RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
- PatternBenefit benefit) {
- patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
- benefit);
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<WarpOpTransferWrite>(patterns.getContext(), benefit);
}
void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
- RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
+ RewritePatternSet &patterns,
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit) {
patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
WarpOpBroadcast, WarpOpShapeCast, WarpOpExtract,
@@ -1739,8 +1700,7 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
WarpOpInsert>(patterns.getContext(), benefit);
patterns.add<WarpOpExtractElement>(patterns.getContext(),
warpShuffleFromIdxFn, benefit);
- patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
- benefit);
+ patterns.add<WarpOpScfForOp>(patterns.getContext(), benefit);
}
void mlir::vector::populateDistributeReduction(
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index c05e9cafd04323d..aad893d54d17567 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -561,6 +561,52 @@ func.func @warp_scf_for_multi_reduce(%arg0: memref<2x32x40x384xf32>, %arg1: memr
// -----
+// CHECK-PROP-LABEL: func @warp_multi_reduce_3d(
+// CHECK-PROP-NOT: vector.warp_execute_on_lane_0
+// CHECK-PROP: vector.transfer_read {{.*}} : memref<128x4x64xf32>, vector<1x2x64xf32>
+// CHECK-PROP: vector.shape_cast {{.*}} : vector<1x2x64xf32> to vector<128xf32>
+// CHECK-PROP: vector.reduction <add>, {{.*}} : vector<128xf32> into f32
+// CHECK-PROP-COUNT-8: gpu.shuffle
+func.func @warp_multi_reduce_3d(%arg0 : memref<128x4x64xf32>) -> f32 {
+ %0 = gpu.thread_id x
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %ret = vector.warp_execute_on_lane_0(%0)[256] -> f32 {
+ %read = vector.transfer_read %arg0[%c0, %c0, %c0], %cst { in_bounds = [true, true, true] } : memref<128x4x64xf32>, vector<128x4x64xf32>
+ %cast = vector.shape_cast %read : vector<128x4x64xf32> to vector<32768xf32>
+ %out = vector.reduction <add>, %cast, %cst : vector<32768xf32> into f32
+ vector.yield %out : f32
+ }
+ func.return %ret : f32
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func @warp_multi_dim_diff_read_cast(
+// CHECK-PROP-NOT: vector.warp_execute_on_lane_0
+// CHECK-PROP: vector.transfer_read {{.*}} : memref<2x4x16xf32>, vector<1x2x16xf32>
+// CHECK-PROP: vector.transfer_read {{.*}} : memref<128xf32>, vector<32xf32>
+// CHECK-PROP: vector.shape_cast {{.*}} : vector<1x2x16xf32> to vector<32xf32>
+// CHECK-PROP: arith.addf {{.*}} : vector<32xf32>
+// CHECK-PROP: vector.reduction <add>, {{.*}} : vector<32xf32> into f32
+// CHECK-PROP-COUNT-2: gpu.shuffle
+func.func @warp_multi_dim_diff_read_cast(%arg0 : memref<2x4x16xf32>, %arg1 : memref<128xf32>) -> f32 {
+ %0 = gpu.thread_id x
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %ret = vector.warp_execute_on_lane_0(%0)[4] -> f32 {
+ %read = vector.transfer_read %arg0[%c0, %c0, %c0], %cst { in_bounds = [true, true, true] } : memref<2x4x16xf32>, vector<2x4x16xf32>
+ %read1 = vector.transfer_read %arg1[%c0], %cst { in_bounds = [true] } : memref<128xf32>, vector<128xf32>
+ %cast = vector.shape_cast %read : vector<2x4x16xf32> to vector<128xf32>
+ %added = arith.addf %cast, %read1 : vector<128xf32>
+ %reduced = vector.reduction <add>, %added : vector<128xf32> into f32
+ vector.yield %reduced : f32
+ }
+ func.return %ret : f32
+}
+
+// -----
+
// CHECK-PROP-LABEL: func @vector_reduction(
// CHECK-PROP-SAME: %[[laneid:.*]]: index)
// CHECK-PROP-DAG: %[[c1:.*]] = arith.constant 1 : i32
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index ed981fcdfe60d87..9572132c0b2f4a0 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -568,19 +568,6 @@ struct TestVectorDistribution
}
});
MLIRContext *ctx = &getContext();
- auto distributionFn = [](Value val) {
- // Create a map (d0, d1) -> (d1, d0) to distribute starting from the inner
- // dimensions.
- VectorType vecType = dyn_cast<VectorType>(val.getType());
- int64_t vecRank = vecType ? vecType.getRank() : 0;
- OpBuilder builder(val.getContext());
- SmallVector<AffineExpr, 4> vecDims =
- llvm::map_to_vector(llvm::seq<int64_t>(0, vecRank), [&](int64_t i) {
- return builder.getAffineDimExpr(vecRank - i - 1);
- });
- return AffineMap::get(vecRank, /*symbolCount=*/0, vecDims,
- builder.getContext());
- };
auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
Value srcIdx, int64_t warpSz) {
assert((val.getType().isF32() || val.getType().isInteger(32)) &&
@@ -598,13 +585,13 @@ struct TestVectorDistribution
};
if (distributeTransferWriteOps) {
RewritePatternSet patterns(ctx);
- populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
+ populateDistributeTransferWriteOpPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
if (propagateDistribution) {
RewritePatternSet patterns(ctx);
- vector::populatePropagateWarpVectorDistributionPatterns(
- patterns, distributionFn, shuffleFn);
+ vector::populatePropagateWarpVectorDistributionPatterns(patterns,
+ shuffleFn);
vector::populateDistributeReduction(patterns, warpReduction);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
From 22e7ef3f0ec08f4149e9a803c3a67f1de4a365a5 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Thu, 9 Nov 2023 19:54:36 +0530
Subject: [PATCH 7/7] Fix bugs to make tests pass
---
.../Vector/Transforms/VectorDistribution.h | 2 +-
.../Vector/Transforms/VectorDistribute.cpp | 41 +++++++++++--------
2 files changed, 26 insertions(+), 17 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index fbb59abfce5e49f..cf104b54ed0500d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -40,7 +40,7 @@ void populateWarpExecuteOnLane0OpToScfForPattern(
const WarpExecuteOnLane0LoweringOptions &options,
PatternBenefit benefit = 1);
-/// Distribute transfer_write ops based.
+/// Distribute transfer_write ops.
/// TODO: Add documentation here on how the distribution is done.
/// Example:
/// ```
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 09ed1111df72d39..20f5254b4dd1e30 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -66,8 +66,12 @@ struct DistributedLoadStoreHelper {
int64_t rank = sequentialVectorType.getRank();
SmallVector<Value> indices(rank, zero);
if (val == distributedVal) {
- for (auto index : llvm::seq<int64_t>(0, rank))
+ for (auto index : llvm::seq<int64_t>(0, rank)) {
+ if (sequentialVectorType.getDimSize(index) ==
+ distributedVectorType.getDimSize(index))
+ continue;
indices[index] = buildDistributedOffset(b, loc, index);
+ }
}
SmallVector<bool> inBounds(indices.size(), true);
return b.create<vector::TransferWriteOp>(
@@ -110,8 +114,12 @@ struct DistributedLoadStoreHelper {
int64_t rank = sequentialVectorType.getRank();
SmallVector<Value> indices(rank, zero);
if (type == distributedVectorType) {
- for (auto index : llvm::seq<int64_t>(0, rank))
+ for (auto index : llvm::seq<int64_t>(0, rank)) {
+ if (sequentialVectorType.getDimSize(index) ==
+ distributedVectorType.getDimSize(index))
+ continue;
indices[index] = buildDistributedOffset(b, loc, index);
+ }
}
SmallVector<bool> inBounds(indices.size(), true);
return b.create<vector::TransferReadOp>(
@@ -521,18 +529,16 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
Location loc = newWriteOp.getLoc();
SmallVector<Value> indices(newWriteOp.getIndices().begin(),
newWriteOp.getIndices().end());
- for (AffineExpr it : indexMap.getResults()) {
+ for (auto [index, it] : llvm::enumerate(indexMap.getResults())) {
AffineExpr d0, d1;
bindDims(newWarpOp.getContext(), d0, d1);
auto indexExpr = it.dyn_cast<AffineDimExpr>();
if (!indexExpr)
continue;
- unsigned indexPos = indexExpr.getPosition();
- auto scale =
- rewriter.getAffineConstantExpr(targetType.getDimSize(indexPos));
- indices[indexPos] = affine::makeComposedAffineApply(
+ auto scale = rewriter.getAffineConstantExpr(targetType.getDimSize(index));
+ indices[index] = affine::makeComposedAffineApply(
rewriter, loc, d0 + scale * d1,
- {indices[indexPos], newWarpOp.getLaneid()});
+ {indices[index], newWarpOp.getLaneid()});
}
newWriteOp.getIndicesMutable().assign(indices);
@@ -746,8 +752,8 @@ bool delinearizeLaneId(OpBuilder &builder, Location loc,
std::multiplies<int64_t>()) != warpSize)
return false;
- AffineExpr s0, s1;
- bindSymbols(builder.getContext(), s0, s1);
+ AffineExpr s0;
+ bindSymbols(builder.getContext(), s0);
int64_t usedThreads = 1;
@@ -825,21 +831,24 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
warpOp.getLaneid(), delinearizedIds))
return rewriter.notifyMatchFailure(
read, "cannot delinearize lane ID for distribution");
- assert(!delinearizedIds.empty());
SmallVector<Value, 4> indices(read.getIndices().begin(),
read.getIndices().end());
- for (AffineExpr it : indexMap.getResults()) {
+ for (auto [index, it] : llvm::enumerate(indexMap.getResults())) {
AffineExpr d0, d1;
bindDims(read.getContext(), d0, d1);
auto indexExpr = it.dyn_cast<AffineDimExpr>();
if (!indexExpr)
continue;
- unsigned indexPos = indexExpr.getPosition();
- auto scale = distributedType.getDimSize(indexPos);
- indices[indexPos] = affine::makeComposedAffineApply(
+ // If the dimension is not distributed, we don't need to change indexing.
+ if (sequentialType.getDimSize(index) == distributedType.getDimSize(index))
+ continue;
+ auto scale = distributedType.getDimSize(index);
+ indices[index] = affine::makeComposedAffineApply(
rewriter, read.getLoc(), d0 + scale * d1,
- {indices[indexPos], delinearizedIds[indexPos]});
+ {indices[index], delinearizedIds[index]});
}
auto newRead = rewriter.create<vector::TransferReadOp>(
read.getLoc(), distributedVal.getType(), read.getSource(), indices,
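Peeled out of the affine apply, the per-lane index rewrite in this hunk is
just strided arithmetic: each distributed dimension of the read is offset by
the lane's position along that dimension times the per-lane slice size. As a
one-function sketch (illustrative name):

#include <cstdint>

// Mirrors the `d0 + scale * d1` affine apply built above.
int64_t distributedIndex(int64_t baseIndex, int64_t perLaneSliceSize,
                         int64_t laneIdAlongDim) {
  return baseIndex + perLaneSliceSize * laneIdAlongDim;
}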