[Mlir-commits] [mlir] [mlir][vector] Add support for multi-dim reduction vector distribution (PR #71193)

Kunwar Grover llvmlistbot at llvm.org
Fri Nov 3 08:53:48 PDT 2023


https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/71193

None

>From 2a45381d24a605c644028f613a131956eebb8268 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Fri, 3 Nov 2023 21:12:01 +0530
Subject: [PATCH] [mlir][vector] Add support for multi-dim reduction vector
 distribution

---
 .../Vector/Transforms/VectorDistribute.cpp    | 49 +++++++++++++++----
 1 file changed, 39 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 8b4575e96875409..13648932cf7b8f0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -425,23 +425,48 @@ static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
 /// Return the distributed vector type based on the original type and the
 /// distribution map. The map is expected to have a dimension equal to the
 /// original type rank and should be a projection where the results are the
-/// distributed dimensions. The number of results should be equal to the number
-/// of warp sizes which is currently limited to 1.
-/// Example: For a vector<16x32x64> distributed with a map(d0, d1, d2) -> (d1)
-/// and a warp size of 16 would distribute the second dimension (associated to
-/// d1) and return vector<16x2x64>
+/// distributed dimensions. The vector should be completely distributable, i.e.
+/// the linearized shape should be a multiple of the warp size.
+/// Example (single-dim): For a vector<16x32x64> distributed with
+/// a map(d0, d1, d2) -> (d1) and a warp size of 16 would distribute the second
+/// dimension (associated to d1) and return vector<16x2x64>.
+/// Example (multi-dim): For a vector<16x32x64> distributed with a
+/// map(d0, d1, d2) -> (d1, d2), and a warp size of 128 would distribute first
+/// the second dimension and then the third dimension, finally returning a
+/// vector<16x1x16>.
 static VectorType getDistributedType(VectorType originalType, AffineMap map,
                                      int64_t warpSize) {
-  if (map.getNumResults() != 1)
-    return VectorType();
+  assert(map.isProjectedPermutation() && "expected projected permutation map");
+
   SmallVector<int64_t> targetShape(originalType.getShape().begin(),
                                    originalType.getShape().end());
+  // Distribute the vector based on the order of dimensions in the affine map.
+  int64_t availableThreads = warpSize;
   for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
     unsigned position = map.getDimPosition(i);
-    if (targetShape[position] % warpSize != 0)
-      return VectorType();
-    targetShape[position] = targetShape[position] / warpSize;
+    int64_t &dimSize = targetShape[position];
+    if (availableThreads > dimSize) {
+      // We have more threads available than the size of the dimension, so we
+      // distribute the whole dimension.
+      if (availableThreads % dimSize != 0)
+        return VectorType();
+      availableThreads = availableThreads / dimSize;
+      dimSize = 1;
+    } else {
+      // The dimension is at least as large as the number of available threads,
+      // so we distribute a part of the dimension to each thread.
+      if (dimSize % availableThreads != 0)
+        return VectorType();
+      dimSize = dimSize / availableThreads;
+      availableThreads = 1;
+      break;
+    }
   }
+
+  // If we could not distribute the whole vector, we fail.
+  if (availableThreads != 1)
+    return VectorType();
+
   VectorType targetType =
       VectorType::get(targetShape, originalType.getElementType());
   return targetType;
@@ -1485,6 +1510,10 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
           }
         });
 
+    // Check if any types could not be distributed.
+    if (llvm::any_of(distTypes, [](Type t) { return !t; }))
+      return failure();
+
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, escapingValues.getArrayRef(), distTypes,



More information about the Mlir-commits mailing list