[Mlir-commits] [mlir] c2b9529 - [mlir][vector] Fix n-d transfer write distribution (#83215)

llvmlistbot at llvm.org
Tue Feb 27 21:11:32 PST 2024


Author: Quinn Dawkins
Date: 2024-02-28T00:11:28-05:00
New Revision: c2b952926fe8707527cf1b8bab211dc4c7ab9aee

URL: https://github.com/llvm/llvm-project/commit/c2b952926fe8707527cf1b8bab211dc4c7ab9aee
DIFF: https://github.com/llvm/llvm-project/commit/c2b952926fe8707527cf1b8bab211dc4c7ab9aee.diff

LOG: [mlir][vector] Fix n-d transfer write distribution (#83215)

Currently, n-d transfer write distribution can be inconsistent with the
distribution of reductions if a value has multiple users, one of which
is a transfer_write with a non-standard distribution map and another of
which is a vector.reduction.
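
As a hedged sketch of the problematic pattern (illustrative only, not
code from this patch; vector.multi_reduction stands in for the
reduction user, since the value here is n-d):

    func.func @example(%laneid: index, %dest: memref<2x32xf32>) {
      %c0 = arith.constant 0 : index
      %acc = arith.constant 0.0 : f32
      %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
        %v = "some_def"() : () -> (vector<2x32xf32>)
        // User 1: distributed according to the write's distribution map.
        vector.transfer_write %v, %dest[%c0, %c0] : vector<2x32xf32>, memref<2x32xf32>
        // User 2: distributed according to the value's distribution map.
        %sum = vector.multi_reduction <add>, %v, %acc [0, 1] : vector<2x32xf32> to f32
        vector.yield %sum : f32
      }
      return
    }

Before this fix, the two paths could assign elements of %v to lanes
differently, producing inconsistent distributed IR.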

We may want to consider removing the distribution map functionality in
the future for this reason.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 620ceee48b196d..b3ab4a916121e3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -443,15 +443,24 @@ static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
 /// d1) and return vector<16x2x64>
 static VectorType getDistributedType(VectorType originalType, AffineMap map,
                                      int64_t warpSize) {
-  if (map.getNumResults() != 1)
-    return VectorType();
   SmallVector<int64_t> targetShape(originalType.getShape().begin(),
                                    originalType.getShape().end());
   for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
     unsigned position = map.getDimPosition(i);
-    if (targetShape[position] % warpSize != 0)
-      return VectorType();
+    if (targetShape[position] % warpSize != 0) {
+      if (warpSize % targetShape[position] != 0) {
+        return VectorType();
+      }
+      warpSize /= targetShape[position];
+      targetShape[position] = 1;
+      continue;
+    }
     targetShape[position] = targetShape[position] / warpSize;
+    warpSize = 1;
+    break;
+  }
+  if (warpSize != 1) {
+    return VectorType();
   }
   VectorType targetType =
       VectorType::get(targetShape, originalType.getElementType());
@@ -526,7 +535,30 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
     // 4. Reindex the write using the distribution map.
     auto newWarpOp =
         newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
+
+    // Delinearize the lane id based on the way threads are divided across the
+    // vector. To get the number of threads per vector dimension, divide the
+    // sequential size by the distributed size along each dim.
     rewriter.setInsertionPoint(newWriteOp);
+    SmallVector<OpFoldResult> delinearizedIdSizes;
+    for (auto [seqSize, distSize] :
+         llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
+      assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
+      delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
+    }
+    SmallVector<Value> delinearized;
+    if (map.getNumResults() > 1) {
+      delinearized = rewriter
+                         .create<mlir::affine::AffineDelinearizeIndexOp>(
+                             newWarpOp.getLoc(), newWarpOp.getLaneid(),
+                             delinearizedIdSizes)
+                         .getResults();
+    } else {
+      // If there is only one map result, we can elide the delinearization
+      // op and use the lane id directly.
+      delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
+    }
+
     AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
     Location loc = newWriteOp.getLoc();
     SmallVector<Value> indices(newWriteOp.getIndices().begin(),
@@ -539,11 +571,11 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
         continue;
       unsigned indexPos = indexExpr.getPosition();
       unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
+      Value laneId = delinearized[vectorPos];
       auto scale =
           rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
       indices[indexPos] = affine::makeComposedAffineApply(
-          rewriter, loc, d0 + scale * d1,
-          {indices[indexPos], newWarpOp.getLaneid()});
+          rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
     }
     newWriteOp.getIndicesMutable().assign(indices);
 

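Taken together, the hunks above first compute the distributed type by
dividing the warp size across the mapped dimensions (for example, a
vector<4x1024xf32> write with an identity distribution map and a warp
size of 32 distributes to vector<1x128xf32>, using 4 lanes along dim 0
and the remaining 8 along dim 1), and then delinearize the lane id so
that each distributed dimension gets its own index. A minimal sketch of
the IR the pattern now emits for that case (SSA names and the index
constants %c4/%c8 are assumed defined elsewhere, as in the test below):

    %ids:2 = affine.delinearize_index %laneid into (%c4, %c8) : index, index
    %col = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%ids#1]
    vector.transfer_write %w, %dest[%ids#0, %col]
        : vector<1x128xf32>, memref<4x1024xf32>
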
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 9072603734879e..bf90c4a6ebb3c2 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1559,3 +1559,28 @@ func.func @warp_propagate_multi_dim_create_mask(%laneid: index, %m0: index, %m1:
 //       CHECK-PROP:   %[[DISTM0:.+]] = affine.apply #[[$SUBM0]]()[%[[M0]], %[[LANEID]]]
 //       CHECK-PROP:   %[[DISTM1:.+]] = affine.apply #[[$SUBM1]]()[%[[M1]], %[[LANEID]]]
 //       CHECK-PROP:   vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[M2]] : vector<1x2x4xi1>
+
+// -----
+
+func.func @warp_propagate_nd_write(%laneid: index, %dest: memref<4x1024xf32>) {
+  %c0 = arith.constant 0 : index
+  vector.warp_execute_on_lane_0(%laneid)[32] -> () {
+    %0 = "some_def"() : () -> (vector<4x1024xf32>)
+    vector.transfer_write %0, %dest[%c0, %c0] : vector<4x1024xf32>, memref<4x1024xf32>
+    vector.yield
+  }
+  return
+}
+
+//       CHECK-DIST-AND-PROP: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 128)>
+
+// CHECK-DIST-AND-PROP-LABEL: func.func @warp_propagate_nd_write(
+//       CHECK-DIST-AND-PROP:   %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1x128xf32>) {
+//       CHECK-DIST-AND-PROP:     %[[V0:.*]] = "some_def"
+//       CHECK-DIST-AND-PROP:     vector.yield %[[V0]]
+//  CHECK-DIST-AND-PROP-SAME:       vector<4x1024xf32>
+//       CHECK-DIST-AND-PROP:   }
+
+//       CHECK-DIST-AND-PROP:   %[[IDS:.+]]:2 = affine.delinearize_index %{{.*}} into (%c4, %c8) : index, index
+//       CHECK-DIST-AND-PROP:   %[[INNER_ID:.+]] = affine.apply #map()[%[[IDS]]#1]
+//       CHECK-DIST-AND-PROP:   vector.transfer_write %[[W]], %{{.*}}[%[[IDS]]#0, %[[INNER_ID]]] {{.*}} : vector<1x128xf32>

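To make the new CHECK lines concrete: with the (4, 8) lane grid above,
lane 13 delinearizes to (13 floordiv 8, 13 mod 8) = (1, 5), so it
writes its vector<1x128xf32> slice at row 1, columns [640, 768) of the
memref<4x1024xf32>, since 5 * 128 = 640.
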
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 178a58e796b246..915f713f7047be 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -630,15 +630,13 @@ struct TestVectorDistribution
     });
     MLIRContext *ctx = &getContext();
     auto distributionFn = [](Value val) {
-      // Create a map (d0, d1) -> (d1) to distribute along the inner
-      // dimension. Once we support n-d distribution we can add more
-      // complex cases.
+      // Create an identity dim map of the same rank as the vector.
       VectorType vecType = dyn_cast<VectorType>(val.getType());
       int64_t vecRank = vecType ? vecType.getRank() : 0;
       OpBuilder builder(val.getContext());
       if (vecRank == 0)
         return AffineMap::get(val.getContext());
-      return AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
+      return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
     };
     auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
                         Value srcIdx, int64_t warpSz) {

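In affine_map notation, the effect of this change on the test's
distribution map for a rank-2 vector is:

    // Before: only the innermost dimension was distributed.
    affine_map<(d0, d1) -> (d1)>
    // After: the multi-dim identity map distributes every dimension.
    affine_map<(d0, d1) -> (d0, d1)>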
