[Mlir-commits] [mlir] [mlir][vector] Fix n-d transfer write distribution (PR #83215)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 27 18:10:28 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Quinn Dawkins (qedawkins)

<details>
<summary>Changes</summary>

Currently, n-d transfer write distribution can be inconsistent with the distribution of reductions when a value has multiple users, one of which is a `vector.transfer_write` with a non-standard distribution map and another of which is a `vector.reduction`.

We may want to consider removing the distribution map functionality in the future for this reason.

---
Full diff: https://github.com/llvm/llvm-project/pull/83215.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+33-6) 
- (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+25) 
- (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+1-1) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 620ceee48b196d..34f25d2312ef8f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -443,15 +443,24 @@ static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
 /// d1) and return vector<16x2x64>
 static VectorType getDistributedType(VectorType originalType, AffineMap map,
                                      int64_t warpSize) {
-  if (map.getNumResults() != 1)
-    return VectorType();
   SmallVector<int64_t> targetShape(originalType.getShape().begin(),
                                    originalType.getShape().end());
   for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
     unsigned position = map.getDimPosition(i);
-    if (targetShape[position] % warpSize != 0)
-      return VectorType();
+    if (targetShape[position] % warpSize != 0) {
+      if (warpSize % targetShape[position] != 0) {
+        return VectorType();
+      }
+      warpSize /= targetShape[position];
+      targetShape[position] = 1;
+      continue;
+    }
     targetShape[position] = targetShape[position] / warpSize;
+    warpSize = 1;
+    break;
+  }
+  if (warpSize != 1) {
+    return VectorType();
   }
   VectorType targetType =
       VectorType::get(targetShape, originalType.getElementType());
@@ -526,7 +535,25 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
     // 4. Reindex the write using the distribution map.
     auto newWarpOp =
         newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
+
     rewriter.setInsertionPoint(newWriteOp);
+    SmallVector<OpFoldResult> delinearizedIdSizes;
+    for (auto [seqSize, distSize] :
+         llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
+      assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
+      delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
+    }
+    SmallVector<Value> delinearized;
+    if (map.getNumResults() > 1) {
+      delinearized = rewriter
+                         .create<mlir::affine::AffineDelinearizeIndexOp>(
+                             newWarpOp.getLoc(), newWarpOp.getLaneid(),
+                             delinearizedIdSizes)
+                         .getResults();
+    } else {
+      delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
+    }
+
     AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
     Location loc = newWriteOp.getLoc();
     SmallVector<Value> indices(newWriteOp.getIndices().begin(),
@@ -539,11 +566,11 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
         continue;
       unsigned indexPos = indexExpr.getPosition();
       unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
+      Value laneId = delinearized[vectorPos];
       auto scale =
           rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
       indices[indexPos] = affine::makeComposedAffineApply(
-          rewriter, loc, d0 + scale * d1,
-          {indices[indexPos], newWarpOp.getLaneid()});
+          rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
     }
     newWriteOp.getIndicesMutable().assign(indices);
 
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 9072603734879e..bf90c4a6ebb3c2 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1559,3 +1559,28 @@ func.func @warp_propagate_multi_dim_create_mask(%laneid: index, %m0: index, %m1:
 //       CHECK-PROP:   %[[DISTM0:.+]] = affine.apply #[[$SUBM0]]()[%[[M0]], %[[LANEID]]]
 //       CHECK-PROP:   %[[DISTM1:.+]] = affine.apply #[[$SUBM1]]()[%[[M1]], %[[LANEID]]]
 //       CHECK-PROP:   vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[M2]] : vector<1x2x4xi1>
+
+// -----
+
+func.func @warp_propagate_nd_write(%laneid: index, %dest: memref<4x1024xf32>) {
+  %c0 = arith.constant 0 : index
+  vector.warp_execute_on_lane_0(%laneid)[32] -> () {
+    %0 = "some_def"() : () -> (vector<4x1024xf32>)
+    vector.transfer_write %0, %dest[%c0, %c0] : vector<4x1024xf32>, memref<4x1024xf32>
+    vector.yield
+  }
+  return
+}
+
+//       CHECK-DIST-AND-PROP: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 128)>
+
+// CHECK-DIST-AND-PROP-LABEL: func.func @warp_propagate_nd_write(
+//       CHECK-DIST-AND-PROP:   %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1x128xf32>) {
+//       CHECK-DIST-AND-PROP:     %[[V0:.*]] = "some_def"
+//       CHECK-DIST-AND-PROP:     vector.yield %[[V0]]
+//  CHECK-DIST-AND-PROP-SAME:       vector<4x1024xf32>
+//       CHECK-DIST-AND-PROP:   }
+
+//       CHECK-DIST-AND-PROP:   %[[IDS:.+]]:2 = affine.delinearize_index %{{.*}} into (%c4, %c8) : index, index
+//       CHECK-DIST-AND-PROP:   %[[INNER_ID:.+]] = affine.apply #map()[%[[IDS]]#1]
+//       CHECK-DIST-AND-PROP:   vector.transfer_write %[[W]], %{{.*}}[%[[IDS]]#0, %[[INNER_ID]]] {{.*}} : vector<1x128xf32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 178a58e796b246..d8453c78d52d4f 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -638,7 +638,7 @@ struct TestVectorDistribution
       OpBuilder builder(val.getContext());
       if (vecRank == 0)
         return AffineMap::get(val.getContext());
-      return AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
+      return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
     };
     auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
                         Value srcIdx, int64_t warpSz) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/83215


More information about the Mlir-commits mailing list