[Mlir-commits] [mlir] [mlir][vector] Add support for multi-dim reduction vector distribution (PR #71193)
Kunwar Grover
llvmlistbot at llvm.org
Wed Nov 8 02:24:18 PST 2023
https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/71193
>From 2a45381d24a605c644028f613a131956eebb8268 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Fri, 3 Nov 2023 21:12:01 +0530
Subject: [PATCH 1/5] [mlir][vector] Add support for multi-dim reduction vector
distribution
---
.../Vector/Transforms/VectorDistribute.cpp | 49 +++++++++++++++----
1 file changed, 39 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 8b4575e96875409..13648932cf7b8f0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -425,23 +425,48 @@ static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
/// Return the distributed vector type based on the original type and the
/// distribution map. The map is expected to have a dimension equal to the
/// original type rank and should be a projection where the results are the
-/// distributed dimensions. The number of results should be equal to the number
-/// of warp sizes which is currently limited to 1.
-/// Example: For a vector<16x32x64> distributed with a map(d0, d1, d2) -> (d1)
-/// and a warp size of 16 would distribute the second dimension (associated to
-/// d1) and return vector<16x2x64>
+/// distributed dimensions. The vector should be completely distributably, i.e.
+/// the linearized shape should be a multiple of the warp size.
+/// Example (single-dim): For a vector<16x32x64> distributed with
+/// a map(d0, d1, d2) -> (d1) and a warp size of 16 would distribute the second
+/// dimension (associated to d1) and return vector<16x2x64>.
+/// Example (multi-dim): For a vector<16x32x64> distributed with a
+/// map(d0, d1, d2) -> (d1, d2), and a warp size of 128 would distribute first
+/// the second dimension and then the third dimension, finally returning a
+/// vector <4x1x64>.
static VectorType getDistributedType(VectorType originalType, AffineMap map,
int64_t warpSize) {
- if (map.getNumResults() != 1)
- return VectorType();
+ assert(map.isProjectedPermutation() && "expected projected permutation map");
+
SmallVector<int64_t> targetShape(originalType.getShape().begin(),
originalType.getShape().end());
+ // Distribute the vector based on the order of dimensions in the affine map.
+ int64_t availableThreads = warpSize;
for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
unsigned position = map.getDimPosition(i);
- if (targetShape[position] % warpSize != 0)
- return VectorType();
- targetShape[position] = targetShape[position] / warpSize;
+ int64_t &dimSize = targetShape[position];
+ if (availableThreads > dimSize) {
+ // We have more threads available than the size of the dimension, so we
+ // distribute the whole dimension.
+ if (availableThreads % dimSize != 0)
+ return VectorType();
+ availableThreads = availableThreads / dimSize;
+ dimSize = 1;
+ } else {
+ // We have the dimension is bigger than the number of threads available,
+ // so we distribute a part of the dimension to each thread.
+ if (dimSize % availableThreads != 0)
+ return VectorType();
+ dimSize = dimSize / availableThreads;
+ availableThreads = 1;
+ break;
+ }
}
+
+ // If we could not distribute the whole vector, we fail.
+ if (availableThreads != 1)
+ return VectorType();
+
VectorType targetType =
VectorType::get(targetShape, originalType.getElementType());
return targetType;
@@ -1485,6 +1510,10 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
});
+ // Check if any types could not be distributed.
+ if (llvm::any_of(distTypes, [](Type t) { return !t; }))
+ return failure();
+
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
>From 8bbe490538fca9686a4aea7b91c6912a27323902 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Sat, 4 Nov 2023 12:28:01 +0530
Subject: [PATCH 2/5] Add tests
---
.../Vector/vector-warp-distribute.mlir | 56 +++++++++++++++++++
.../Dialect/Vector/TestVectorTransforms.cpp | 14 +++--
2 files changed, 64 insertions(+), 6 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 3bb981c7a623886..67bfc6b47fe29c0 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -494,6 +494,62 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2
// -----
+// CHECK-PROP-LABEL: func @warp_scf_for_multi_reduce(
+// CHECK-PROP-NOT: vector.warp_execute_on_lane_0
+// CHECK-PROP: scf.for {{.*}} -> (vector<1x4xf32>) {
+// CHECK-PROP: scf.for {{.*}} -> (vector<1x4xf32>) {
+// CHECK-PROP: vector.transfer_read {{.*}} : memref<2x32x40x384xf32>, vector<1x4xf32>
+// CHECK-PROP: }
+// CHECK-PROP: }
+// CHECK-PROP: vector.reduction <add>
+// CHECK-PROP: gpu.shuffle
+#map = affine_map<(d0, d1) -> (0, 0)>
+func.func @warp_scf_for_multi_reduce(%arg0: memref<2x32x40x384xf32>, %arg1: memref<2x32x40x384xf16>, %arg2: memref<2x32xf32>, %arg3: memref<2x32x40x384xf16>) {
+ %cst = arith.constant dense<1.536000e+04> : vector<8x128xf32>
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<8x128xf32>
+ %cst_1 = arith.constant 9.99999997E-7 : f32
+ %c128 = arith.constant 128 : index
+ %c8 = arith.constant 8 : index
+ %c0 = arith.constant 0 : index
+ %c40 = arith.constant 40 : index
+ %c384 = arith.constant 384 : index
+ %cst_2 = arith.constant 0.000000e+00 : f16
+ %cst_3 = arith.constant 0.000000e+00 : f32
+ %0 = gpu.thread_id x
+ %1 = arith.truncf %cst_1 : f32 to f16
+ vector.warp_execute_on_lane_0(%0)[256] {
+ %2 = scf.for %arg4 = %c0 to %c40 step %c8 iter_args(%arg5 = %cst_0) -> (vector<8x128xf32>) {
+ %11 = scf.for %arg6 = %c0 to %c384 step %c128 iter_args(%arg7 = %arg5) -> (vector<8x128xf32>) {
+ %12 = vector.transfer_read %arg0[%c0, %c0, %arg4, %arg6], %cst_3 {in_bounds = [true, true]} : memref<2x32x40x384xf32>, vector<8x128xf32>
+ %13 = arith.addf %12, %arg7 : vector<8x128xf32>
+ scf.yield %13 : vector<8x128xf32>
+ }
+ scf.yield %11 : vector<8x128xf32>
+ }
+ %3 = vector.shape_cast %2 : vector<8x128xf32> to vector<1024xf32>
+ %4 = vector.reduction <add>, %3, %cst_3 : vector<1024xf32> into f32
+ %5 = vector.broadcast %4 : f32 to vector<8x128xf32>
+ %6 = arith.divf %5, %cst : vector<8x128xf32>
+ %7 = arith.truncf %6 : vector<8x128xf32> to vector<8x128xf16>
+ %8 = vector.broadcast %1 : f16 to vector<8x128xf16>
+ %9 = arith.addf %7, %8 : vector<8x128xf16>
+ %10 = math.rsqrt %9 : vector<8x128xf16>
+ scf.for %arg4 = %c0 to %c40 step %c8 {
+ %11 = vector.transfer_read %arg2[%c0, %c0], %cst_3 {in_bounds = [true, true], permutation_map = #map} : memref<2x32xf32>, vector<8x128xf32>
+ %12 = arith.truncf %11 : vector<8x128xf32> to vector<8x128xf16>
+ scf.for %arg5 = %c0 to %c384 step %c128 {
+ %13 = vector.transfer_read %arg1[%c0, %c0, %arg4, %arg5], %cst_2 {in_bounds = [true, true]} : memref<2x32x40x384xf16>, vector<8x128xf16>
+ %14 = arith.subf %13, %12 : vector<8x128xf16>
+ %15 = arith.mulf %14, %10 : vector<8x128xf16>
+ vector.transfer_write %15, %arg3[%c0, %c0, %arg4, %arg5] {in_bounds = [true, true]} : vector<8x128xf16>, memref<2x32x40x384xf16>
+ }
+ }
+ }
+ return
+}
+
+// -----
+
// CHECK-PROP-LABEL: func @vector_reduction(
// CHECK-PROP-SAME: %[[laneid:.*]]: index)
// CHECK-PROP-DAG: %[[c1:.*]] = arith.constant 1 : i32
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 2fbf1babf437f08..cf7799a403a3295 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -569,15 +569,17 @@ struct TestVectorDistribution
});
MLIRContext *ctx = &getContext();
auto distributionFn = [](Value val) {
- // Create a map (d0, d1) -> (d1) to distribute along the inner
- // dimension. Once we support n-d distribution we can add more
- // complex cases.
+ // Create a map (d0, d1) -> (d1, d0) to distribute starting from the inner
+ // dimensions.
VectorType vecType = dyn_cast<VectorType>(val.getType());
int64_t vecRank = vecType ? vecType.getRank() : 0;
OpBuilder builder(val.getContext());
- if (vecRank == 0)
- return AffineMap::get(val.getContext());
- return AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
+ SmallVector<AffineExpr, 4> vecDims = llvm::to_vector<4>(
+ llvm::map_range(llvm::seq<int64_t>(0, vecRank), [&](int64_t i) {
+ return builder.getAffineDimExpr(vecRank - i - 1);
+ }));
+ return AffineMap::get(vecRank, /*symbolCount=*/0, vecDims,
+ builder.getContext());
};
auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
Value srcIdx, int64_t warpSz) {
>From 3fd4e30cb5c010cff1a415f005a91ab5bef238ed Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Wed, 8 Nov 2023 11:13:14 +0530
Subject: [PATCH 3/5] Address comments
---
.../Vector/Transforms/VectorDistribution.h | 10 ++++++
.../Vector/Transforms/VectorDistribute.cpp | 31 +++++++++++++------
.../Vector/vector-warp-distribute.mlir | 11 ++++++-
.../Dialect/Vector/TestVectorTransforms.cpp | 6 ++--
4 files changed, 44 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index a76a58eb5ec6d3c..781a46117602aae 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -40,6 +40,16 @@ void populateWarpExecuteOnLane0OpToScfForPattern(
const WarpExecuteOnLane0LoweringOptions &options,
PatternBenefit benefit = 1);
+/// Given a value having a shaped type, returns the distribution map for that
+/// value. The distribution map represents the order of dimensions in which
+/// the shape should be distributed. The map is expected to be a projected
+/// permutation of the shape dimensions. Examples of distribution maps that
+/// can be returned:
+///
+/// - Type: vector<16x32x64xf32>,
+/// Map: (d0, d1, d2) -> (d1, d2) : Distribute d1, and then d2
+/// - Type: vector<16x32x64xf32>
+/// Map: (d0, d1, d2) -> (d2, d1, d0) : Distribute d2, then d1 and then d0
using DistributionMapFn = std::function<AffineMap(Value)>;
/// Distribute transfer_write ops based on the affine map returned by
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 13648932cf7b8f0..1359b7a1d3feb6f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -423,20 +423,31 @@ static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
}
/// Return the distributed vector type based on the original type and the
-/// distribution map. The map is expected to have a dimension equal to the
-/// original type rank and should be a projection where the results are the
-/// distributed dimensions. The vector should be completely distributably, i.e.
+/// distribution map. The vector should be completely distributable, i.e.
/// the linearized shape should be a multiple of the warp size.
-/// Example (single-dim): For a vector<16x32x64> distributed with
-/// a map(d0, d1, d2) -> (d1) and a warp size of 16 would distribute the second
-/// dimension (associated to d1) and return vector<16x2x64>.
-/// Example (multi-dim): For a vector<16x32x64> distributed with a
+///
+/// The distribution map represents in what order the dimensions of the vector
+/// should be distributed. The map is expected to be a projected permutation of
+/// the vector shape dimensions. Examples of distribution maps:
+/// - (d0, d1, d2) -> (d1, d2) : Distribute d1, and then d2
+/// - (d0, d1, d2) -> (d2, d1, d0) : Distribute d2, then d1 and then d0
+/// If all threads are used while distributing the first few dimensions, the
+/// rest dimensions may not be used for distribution.
+///
+/// Example (single-dim): For a vector<16x32x64> distributed with a
+/// map(d0, d1, d2) -> (d1) and a warp size of 16 would distribute the second
+/// dimension (associated to d1) and return vector<16x2x64>.
+///
+/// Example (multi-dim): For a vector<16x32x64> distributed with a
/// map(d0, d1, d2) -> (d1, d2), and a warp size of 128 would distribute first
/// the second dimension and then the third dimension, finally returning a
/// vector <4x1x64>.
static VectorType getDistributedType(VectorType originalType, AffineMap map,
int64_t warpSize) {
- assert(map.isProjectedPermutation() && "expected projected permutation map");
+ if (!map.isProjectedPermutation()) {
+ assert(false && "expected projected permutation map");
+ return VectorType();
+ }
SmallVector<int64_t> targetShape(originalType.getShape().begin(),
originalType.getShape().end());
@@ -447,14 +458,14 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
int64_t &dimSize = targetShape[position];
if (availableThreads > dimSize) {
// We have more threads available than the size of the dimension, so we
- // distribute the whole dimension.
+ // distribute with size 1 along this dimension.
if (availableThreads % dimSize != 0)
return VectorType();
availableThreads = availableThreads / dimSize;
dimSize = 1;
} else {
- // We have the dimension is bigger than the number of threads available,
- // so we distribute a part of the dimension to each thread.
+ // The dimension is at least as large as the number of threads available,
+ // so each thread gets dimSize / availableThreads elements along it.
if (dimSize % availableThreads != 0)
return VectorType();
dimSize = dimSize / availableThreads;
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 67bfc6b47fe29c0..2eec22fd26cea78 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -502,7 +502,16 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2
// CHECK-PROP: }
// CHECK-PROP: }
// CHECK-PROP: vector.reduction <add>
-// CHECK-PROP: gpu.shuffle
+// CHECK-PROP-COUNT=8: gpu.shuffle
+//
+// CHECK-PROP: scf.for {{.*}} {
+// CHECK-PROP: vector.transfer_read
+// CHECK-PROP: scf.for {{.*}} {
+// CHECK-PROP: vector.warp_execute_on_lane_0
+// CHECK-PROP: vector.transfer_read
+// CHECK-PROP: vector.transfer_write
+// CHECK-PROP: }
+// CHECK-PROP: }
#map = affine_map<(d0, d1) -> (0, 0)>
func.func @warp_scf_for_multi_reduce(%arg0: memref<2x32x40x384xf32>, %arg1: memref<2x32x40x384xf16>, %arg2: memref<2x32xf32>, %arg3: memref<2x32x40x384xf16>) {
%cst = arith.constant dense<1.536000e+04> : vector<8x128xf32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index cf7799a403a3295..ed981fcdfe60d87 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -574,10 +574,10 @@ struct TestVectorDistribution
VectorType vecType = dyn_cast<VectorType>(val.getType());
int64_t vecRank = vecType ? vecType.getRank() : 0;
OpBuilder builder(val.getContext());
- SmallVector<AffineExpr, 4> vecDims = llvm::to_vector<4>(
- llvm::map_range(llvm::seq<int64_t>(0, vecRank), [&](int64_t i) {
+ SmallVector<AffineExpr, 4> vecDims =
+ llvm::map_to_vector(llvm::seq<int64_t>(0, vecRank), [&](int64_t i) {
return builder.getAffineDimExpr(vecRank - i - 1);
- }));
+ });
return AffineMap::get(vecRank, /*symbolCount=*/0, vecDims,
builder.getContext());
};
>From 751f06d907e17360fbd06292525c6cd4b02de645 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Wed, 8 Nov 2023 15:21:11 +0530
Subject: [PATCH 4/5] Update test
---
.../Vector/vector-warp-distribute.mlir | 34 +++++++++----------
1 file changed, 17 insertions(+), 17 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 2eec22fd26cea78..e668caf889563ee 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -494,24 +494,24 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2
// -----
-// CHECK-PROP-LABEL: func @warp_scf_for_multi_reduce(
-// CHECK-PROP-NOT: vector.warp_execute_on_lane_0
-// CHECK-PROP: scf.for {{.*}} -> (vector<1x4xf32>) {
-// CHECK-PROP: scf.for {{.*}} -> (vector<1x4xf32>) {
-// CHECK-PROP: vector.transfer_read {{.*}} : memref<2x32x40x384xf32>, vector<1x4xf32>
-// CHECK-PROP: }
-// CHECK-PROP: }
-// CHECK-PROP: vector.reduction <add>
-// CHECK-PROP-COUNT=8: gpu.shuffle
+// CHECK-PROP-LABEL: func @warp_scf_for_multi_reduce(
+// CHECK-PROP-NOT: vector.warp_execute_on_lane_0
+// CHECK-PROP: scf.for {{.*}} -> (vector<1x4xf32>) {
+// CHECK-PROP: scf.for {{.*}} -> (vector<1x4xf32>) {
+// CHECK-PROP: vector.transfer_read {{.*}} : memref<2x32x40x384xf32>, vector<1x4xf32>
+// CHECK-PROP: }
+// CHECK-PROP: }
+// CHECK-PROP: vector.reduction <add>
+// CHECK-PROP-COUNT-8: gpu.shuffle
//
-// CHECK-PROP: scf.for {{.*}} {
-// CHECK-PROP: vector.transfer_read
-// CHECK-PROP: scf.for {{.*}} {
-// CHECK-PROP: vector.warp_execute_on_lane_0
-// CHECK-PROP: vector.transfer_read
-// CHECK-PROP: vector.transfer_write
-// CHECK-PROP: }
-// CHECK-PROP: }
+// CHECK-PROP: scf.for {{.*}} {
+// CHECK-PROP: vector.transfer_read
+// CHECK-PROP: scf.for {{.*}} {
+// CHECK-PROP: vector.warp_execute_on_lane_0
+// CHECK-PROP: vector.transfer_read
+// CHECK-PROP: vector.transfer_write
+// CHECK-PROP: }
+// CHECK-PROP: }
#map = affine_map<(d0, d1) -> (0, 0)>
func.func @warp_scf_for_multi_reduce(%arg0: memref<2x32x40x384xf32>, %arg1: memref<2x32x40x384xf16>, %arg2: memref<2x32xf32>, %arg3: memref<2x32x40x384xf16>) {
%cst = arith.constant dense<1.536000e+04> : vector<8x128xf32>
>From 078a33ce63328c6ad8ba67c22f9ee82c3f5192c8 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Wed, 8 Nov 2023 15:53:50 +0530
Subject: [PATCH 5/5] Address comments and fix things
---
.../Vector/Transforms/VectorDistribution.h | 23 ++++++++---------
mlir/include/mlir/IR/AffineMap.h | 5 ++++
.../Vector/Transforms/VectorDistribute.cpp | 25 ++++++++-----------
mlir/lib/IR/AffineMap.cpp | 21 ++++++++++++++++
.../Dialect/Vector/TestVectorTransforms.cpp | 2 +-
5 files changed, 48 insertions(+), 28 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index 781a46117602aae..3b1ae34a3acdac9 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -15,6 +15,17 @@ namespace mlir {
class RewritePatternSet;
namespace vector {
+/// Given a value having a shaped type, returns the distribution map for that
+/// value. The distribution map represents the order of dimensions in which
+/// the shape should be distributed. The map is expected to be a projection of
+/// the shape dimensions. Examples of distribution maps that can be returned:
+///
+/// - Type: vector<16x32x64xf32>,
+/// Map: (d0, d1, d2) -> (d1, d2) : Distribute d1, and then d2
+/// - Type: vector<16x32x64xf32>
+/// Map: (d0, d1, d2) -> (d0, d1, d2) : Distribute d0, then d1 and then d2
+using DistributionMapFn = std::function<AffineMap(Value)>;
+
struct WarpExecuteOnLane0LoweringOptions {
/// Lamdba function to let users allocate memory needed for the lowering of
/// WarpExecuteOnLane0Op.
@@ -40,18 +51,6 @@ void populateWarpExecuteOnLane0OpToScfForPattern(
const WarpExecuteOnLane0LoweringOptions &options,
PatternBenefit benefit = 1);
-/// Given a value having a shaped type, returns the distribution map for that
-/// value. The distribution map represents the order of dimensions in which
-/// the shape should be distributed. The map is expected to be a projected
-/// permutation of the shape dimensions. Examples of distribution maps that
-/// can be returned:
-///
-/// - Type: vector<16x32x64xf32>,
-/// Map: (d0, d1, d2) -> (d1, d2) : Distribute d1, and then d2
-/// - Type: vector<16x32x64xf32>
-/// Map: (d0, d1, d2) -> (d2, d1, d0) : Distribute d2, then d1 and then d0
-using DistributionMapFn = std::function<AffineMap(Value)>;
-
/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`.
/// Example:
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 5af7835258f6bd2..b78a6c45360580a 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -330,6 +330,11 @@ class AffineMap {
/// returns the resulting values. `this` must be symbol-less.
SmallVector<int64_t, 4> compose(ArrayRef<int64_t> values) const;
+ /// Returns true if the AffineMap is a projection of a symbol-less identity
+ /// map, i.e. its dimension results appear in strictly increasing position
+ /// order. Non-dimension results are currently not rejected.
+ bool isProjection() const;
+
/// Returns true if the AffineMap represents a subset (i.e. a projection) of a
/// symbol-less permutation map. `allowZeroInResults` allows projected
/// permutation maps with constant zero result expressions.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 1359b7a1d3feb6f..70353cf19e07d14 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -423,29 +423,23 @@ static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
}
/// Return the distributed vector type based on the original type and the
-/// distribution map. The vector should be completely distributable, i.e.
-/// the linearized shape should be a multiple of the warp size.
+/// distribution map. Dimensions are distributed in the order of the map
+/// results: each must evenly divide, or be a multiple of, the number of threads
+/// still available, and every thread must be used. Dimensions left over once
+/// all threads are used are not distributed.
///
-/// The distribution map represents in what order the dimensions of the vector
-/// should be distributed. The map is expected to be a projected permutation of
-/// the vector shape dimensions. Examples of distribution maps:
-/// - (d0, d1, d2) -> (d1, d2) : Distribute d1, and then d2
-/// - (d0, d1, d2) -> (d2, d1, d0) : Distribute d2, then d1 and then d0
-/// If all threads are used while distributing the first few dimensions, the
-/// rest dimensions may not be used for distribution.
-///
-/// Example (single-dim): For a vector<16x32x64> distributed with a
+/// Example (single-dim): For a vector<16x32x64> distributed with a
/// map(d0, d1, d2) -> (d1) and a warp size of 16 would distribute the second
-/// dimension (associated to d1) and return vector<16x2x64>.
+/// dimension (associated to d1) and return vector<16x2x64>.
///
-/// Example (multi-dim): For a vector<16x32x64> distributed with a
+/// Example (multi-dim): For a vector<16x32x64> distributed with a
/// map(d0, d1, d2) -> (d1, d2), and a warp size of 128 would distribute first
/// the second dimension and then the third dimension, finally returning a
-/// vector <4x1x64>.
+/// vector<16x1x16>.
static VectorType getDistributedType(VectorType originalType, AffineMap map,
int64_t warpSize) {
- if (!map.isProjectedPermutation()) {
- assert(false && "expected projected permutation map");
+ if (!map.isProjection()) {
+ assert(false && "expected distribution map to be a projection");
return VectorType();
}
@@ -746,6 +740,7 @@ bool delinearizeLaneId(OpBuilder &builder, Location loc,
ArrayRef<int64_t> originalShape,
ArrayRef<int64_t> distributedShape, int64_t warpSize,
Value laneId, SmallVectorImpl<Value> &delinearizedIds) {
+
// If the original shape and the distributed shape is the same, we don't
// distribute at all--every thread is handling the whole. For such case, we
// should not rely on lane IDs later. So just return an empty lane ID vector.
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 3bd1181b6c7bbd8..a0c9d908833b882 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -532,6 +532,27 @@ SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) const {
return res;
}
+bool AffineMap::isProjection() const {
+ if (getNumSymbols() > 0)
+ return false;
+
+ // A projection cannot have more results than inputs.
+ if (getNumResults() > getNumInputs())
+ return false;
+
+ int64_t current = -1;
+ // A projection must always have dim position > current.
+ for (auto expr : getResults()) {
+ if (auto dim = expr.dyn_cast<AffineDimExpr>()) {
+ if (dim.getPosition() <= current)
+ return false;
+ current = dim.getPosition();
+ }
+ }
+
+ return true;
+}
+
bool AffineMap::isProjectedPermutation(bool allowZeroInResults) const {
if (getNumSymbols() > 0)
return false;
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index ed981fcdfe60d87..b996d87be396077 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
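@@ -572,11 +572,11 @@ struct TestVectorDistribution
- // Create a map (d0, d1) -> (d1, d0) to distribute starting from the inner
- // dimensions.
+ // Create an identity map (d0, d1) -> (d0, d1) to distribute starting from
+ // the outer dimensions.
VectorType vecType = dyn_cast<VectorType>(val.getType());
int64_t vecRank = vecType ? vecType.getRank() : 0;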
OpBuilder builder(val.getContext());
SmallVector<AffineExpr, 4> vecDims =
llvm::map_to_vector(llvm::seq<int64_t>(0, vecRank), [&](int64_t i) {
- return builder.getAffineDimExpr(vecRank - i - 1);
+ return builder.getAffineDimExpr(i);
});
return AffineMap::get(vecRank, /*symbolCount=*/0, vecDims,
builder.getContext());
More information about the Mlir-commits
mailing list