[Mlir-commits] [mlir] [mlir][vector] Support warp distribution of `transfer_read` with dependencies (PR #77779)

Matthias Springer llvmlistbot at llvm.org
Thu Jan 11 07:07:47 PST 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/77779

Support distribution of `vector.transfer_read` ops when their operands are defined inside the region of `warp_execute_on_lane_0`; only the buffer from which the op reads must be defined outside of the region.

Such IR was previously not supported. This commit changes the implementation so that the indices and the padding value are also distributed: they are yielded as additional results of the warp op and consumed by the new `transfer_read` that is created outside of the region.
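
For illustration, a hypothetical before/after sketch (the shapes, `%src`, `%pad`, and the placeholder `"some_op"` are invented; follow-up patterns would additionally erase the now-dead vector result of the warp op):

```mlir
// Before: the index %idx is defined inside the region, so such a
// transfer_read could previously not (in general) be distributed.
%r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
  %idx = "some_op"() : () -> index
  %v = vector.transfer_read %src[%idx], %pad : memref<128xf32>, vector<64xf32>
  vector.yield %v : vector<64xf32>
}

// After: %idx is yielded as an additional result of the warp op and
// adjusted per lane (each of the 32 lanes reads its 2 of the 64 elements
// at offset %idx + 2 * %laneid).
%r:2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>, index) {
  %idx = "some_op"() : () -> index
  %v = vector.transfer_read %src[%idx], %pad : memref<128xf32>, vector<64xf32>
  vector.yield %v, %idx : vector<64xf32>, index
}
%new_idx = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 2)>()[%r#1, %laneid]
%read = vector.transfer_read %src[%new_idx], %pad : memref<128xf32>, vector<2xf32>
```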

This commit also simplifies the implementation considerably: the original implementation created a new `transfer_read` op and then checked whether the new op was valid, failing the rewrite pattern if it was not. That was a bit hacky. It was also a violation of the rewrite pattern API (detected by `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`) because the IR was modified even though the pattern returned "failure". With this change, all checks that can fail run before any IR is modified.
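
A condensed sketch of the new pattern structure, extracted from the patch below (simplified; `operand`, `sequentialType`, `distributedType`, and the appended result vectors are set up as in the full implementation):

```cpp
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                              PatternRewriter &rewriter) const override {
  // Phase 1: all checks that can fail happen up front, before the warp op
  // is modified, so returning failure here is safe under the rewrite
  // pattern API.
  auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
  if (!warpOp.isDefinedOutsideOfRegion(read.getSource()))
    return failure();
  SmallVector<Value> delinearizedIds;
  if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
                         distributedType.getShape(), warpOp.getWarpSize(),
                         warpOp.getLaneid(), delinearizedIds))
    return rewriter.notifyMatchFailure(
        read, "cannot delinearize lane ID for distribution");

  // Phase 2: rewriting; from here on the pattern must succeed. The indices,
  // the padding value, and the mask (if present) are appended to the values
  // yielded by a new warp op ...
  SmallVector<size_t> newRetIndices;
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
      rewriter, warpOp, additionalResults, additionalResultTypes,
      newRetIndices);
  // ... and the distributed transfer_read is created outside of the region
  // from the distributed operands (see the diff for the details).
  return success();
}
```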

From 0aa39db39331fcb0bb0b54898dfbdf55e178945d Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Thu, 11 Jan 2024 15:06:49 +0000
Subject: [PATCH] [mlir][vector] Support warp distribution of `transfer_read`
 with dependencies

Support distribution of `vector.transfer_read` ops when their operands are defined inside the region of `warp_execute_on_lane_0`; only the buffer from which the op reads must be defined outside of the region.

Such IR was previously not supported. This commit changes the implementation so that the indices and the padding value are also distributed: they are yielded as additional results of the warp op and consumed by the new `transfer_read` that is created outside of the region.

This commit also simplifies the implementation considerably: the original implementation created a new `transfer_read` op and then checked whether the new op was valid, failing the rewrite pattern if it was not. That was a bit hacky. It was also a violation of the rewrite pattern API (detected by `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`) because the IR was modified even though the pattern returned "failure". With this change, all checks that can fail run before any IR is modified.
---
 .../Vector/Transforms/VectorDistribute.cpp    | 112 ++++++++----------
 .../Vector/vector-warp-distribute.mlir        |  19 +--
 2 files changed, 60 insertions(+), 71 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 074356ab425377..70fb6320845a47 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -822,6 +822,10 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
       return failure();
     auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
 
+    // Source must be defined outside of the region.
+    if (!warpOp.isDefinedOutsideOfRegion(read.getSource()))
+      return failure();
+
     unsigned operandIndex = operand->getOperandNumber();
     Value distributedVal = warpOp.getResult(operandIndex);
 
@@ -832,10 +836,25 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
     AffineMap map = calculateImplicitMap(sequentialType, distributedType);
     AffineMap indexMap = map.compose(read.getPermutationMap());
 
-    // Distribute the mask if present.
+    // Try to delinearize the lane ID to match the rank expected for
+    // distribution.
+    SmallVector<Value> delinearizedIds;
+    if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
+                           distributedType.getShape(), warpOp.getWarpSize(),
+                           warpOp.getLaneid(), delinearizedIds)) {
+      return rewriter.notifyMatchFailure(
+          read, "cannot delinearize lane ID for distribution");
+    }
+    assert(!delinearizedIds.empty() || map.getNumResults() == 0);
+
+    // Distribute indices and the mask (if present).
     OpBuilder::InsertionGuard g(rewriter);
-    WarpExecuteOnLane0Op newWarpOp = warpOp;
-    Value newMask = read.getMask();
+    SmallVector<Value> additionalResults(indices.begin(), indices.end());
+    SmallVector<Type> additionalResultTypes(indices.size(),
+                                            rewriter.getIndexType());
+    additionalResults.push_back(read.getPadding());
+    additionalResultTypes.push_back(read.getPadding().getType());
+
     bool hasMask = false;
     if (read.getMask()) {
       hasMask = true;
@@ -849,39 +868,22 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
         return failure();
       VectorType maskType =
           getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
-      SmallVector<size_t> newRetIndices;
-      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-          rewriter, warpOp, ValueRange{read.getMask()}, TypeRange{maskType},
-          newRetIndices);
-      newMask = newWarpOp.getResult(newRetIndices[0]);
-      distributedVal = newWarpOp.getResult(operandIndex);
-    } else {
-      // This pattern does not actually change the warp op directly. Instead it
-      // just rewrites a new transfer read (when not masked) outside of the warp
-      // op and replaces the correponding result. There are then follow up
-      // patterns to erase now dead results of the warp op. This erasure allows
-      // propagation to continue, but this pattern on its own never actually
-      // tells the pattern rewriter that the warp op "changed." Notify the
-      // rewriter here that the warp op is changing. Similar situations are
-      // noted in following patterns.
-      rewriter.startRootUpdate(warpOp);
+      additionalResults.push_back(read.getMask());
+      additionalResultTypes.push_back(maskType);
     }
 
-    rewriter.setInsertionPointAfter(newWarpOp);
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, additionalResults, additionalResultTypes,
+        newRetIndices);
+    distributedVal = newWarpOp.getResult(operandIndex);
 
-    // Try to delinearize the lane ID to match the rank expected for
-    // distribution.
-    SmallVector<Value> delinearizedIds;
-    if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
-                           distributedType.getShape(), newWarpOp.getWarpSize(),
-                           newWarpOp.getLaneid(), delinearizedIds)) {
-      if (!hasMask)
-        rewriter.cancelRootUpdate(warpOp);
-      return rewriter.notifyMatchFailure(
-          read, "cannot delinearize lane ID for distribution");
-    }
-    assert(!delinearizedIds.empty() || map.getNumResults() == 0);
+    // Distributed indices were appended first.
+    SmallVector<Value> newIndices;
+    for (int64_t i = 0, e = indices.size(); i < e; ++i)
+      newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));
 
+    rewriter.setInsertionPointAfter(newWarpOp);
     for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
       AffineExpr d0, d1;
       bindDims(read.getContext(), d0, d1);
@@ -891,42 +893,23 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
       unsigned indexPos = indexExpr.getPosition();
       unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
       int64_t scale = distributedType.getDimSize(vectorPos);
-      indices[indexPos] = affine::makeComposedAffineApply(
+      newIndices[indexPos] = affine::makeComposedAffineApply(
           rewriter, read.getLoc(), d0 + scale * d1,
-          {indices[indexPos], delinearizedIds[vectorPos]});
+          {newIndices[indexPos], delinearizedIds[vectorPos]});
     }
+
+    // Distributed padding value was appended right after the indices.
+    Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
+    // Distributed mask value was added at the end (if the op has a mask).
+    Value newMask =
+        hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
+                : Value();
     auto newRead = rewriter.create<vector::TransferReadOp>(
-        read.getLoc(), distributedVal.getType(), read.getSource(), indices,
-        read.getPermutationMapAttr(), read.getPadding(), newMask,
+        read.getLoc(), distributedVal.getType(), read.getSource(), newIndices,
+        read.getPermutationMapAttr(), newPadding, newMask,
         read.getInBoundsAttr());
 
-    // Check that the produced operation is legal.
-    // The transfer op may be reading from values that are defined within
-    // warpOp's body, which is illegal.
-    // We do the check late because incdices may be changed by
-    // makeComposeAffineApply. This rewrite may remove dependencies from
-    // warpOp's body.
-    // E.g., warpop {
-    //   %idx = affine.apply...[%outsideDef]
-    //   ... = transfer_read ...[%idx]
-    // }
-    // will be rewritten in:
-    // warpop {
-    // }
-    //  %new_idx = affine.apply...[%outsideDef]
-    //   ... = transfer_read ...[%new_idx]
-    if (!llvm::all_of(newRead->getOperands(), [&](Value value) {
-          return (newRead.getMask() && value == newRead.getMask()) ||
-                 newWarpOp.isDefinedOutsideOfRegion(value);
-        })) {
-      if (!hasMask)
-        rewriter.cancelRootUpdate(warpOp);
-      return failure();
-    }
-
     rewriter.replaceAllUsesWith(distributedVal, newRead);
-    if (!hasMask)
-      rewriter.finalizeRootUpdate(warpOp);
     return success();
   }
 };
@@ -1315,6 +1298,11 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
     unsigned int operandNumber = operand->getOperandNumber();
     auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
     VectorType extractSrcType = extractOp.getSourceVectorType();
+    // TODO: Supported shuffle types should be parameterizable, similar to
+    // `WarpShuffleFromIdxFn`.
+    if (!extractSrcType.getElementType().isF32() &&
+        !extractSrcType.getElementType().isInteger(32))
+      return failure();
     bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
     Type elType = extractSrcType.getElementType();
     VectorType distributedVecType;
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index e04fa64f0f8a70..7d54d8cfd602bc 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1248,12 +1248,12 @@ func.func @vector_insert_2d_broadcast(%laneid: index) -> (vector<4x96xf32>) {
 
 // -----
 
-// Check that we don't propagate transfer_reads that have dependencies on
-// values inside the warp_execute_on_lane_0.
-// In this case, propagating would create transfer_read that depends on the
-// extractelment defined in the body.
+// Make sure that all operands of the transfer_read op are properly propagated.
+// The vector.extractelement op cannot be propagated because index-typed
+// shuffles are not supported at the moment.
 
-// CHECK-PROP-LABEL: func @transfer_read_no_prop(
+// CHECK-PROP: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-PROP-LABEL: func @transfer_read_prop_operands(
 //  CHECK-PROP-SAME:     %[[IN2:[^ :]*]]: vector<1x2xindex>,
 //  CHECK-PROP-SAME:     %[[AR1:[^ :]*]]: memref<1x4x2xi32>,
 //  CHECK-PROP-SAME:     %[[AR2:[^ :]*]]: memref<1x4x1024xf32>)
@@ -1264,10 +1264,11 @@ func.func @vector_insert_2d_broadcast(%laneid: index) -> (vector<4x96xf32>) {
 //       CHECK-PROP:     %[[EXTRACT:.*]] = vector.extract %[[GATHER]][0] : vector<64xi32> from vector<1x64xi32>
 //       CHECK-PROP:     %[[CAST:.*]] = arith.index_cast %[[EXTRACT]] : vector<64xi32> to vector<64xindex>
 //       CHECK-PROP:     %[[EXTRACTELT:.*]] = vector.extractelement %[[CAST]][{{.*}}: i32] : vector<64xindex>
-//       CHECK-PROP:     %[[TRANSFERREAD:.*]] = vector.transfer_read %[[AR2]][%[[C0]], %[[EXTRACTELT]], %[[C0]]],
-//       CHECK-PROP:     vector.yield %[[TRANSFERREAD]] : vector<64xf32>
-//       CHECK-PROP:   return %[[W]]
-func.func @transfer_read_no_prop(%in2: vector<1x2xindex>, %ar1 :  memref<1x4x2xi32>, %ar2 : memref<1x4x1024xf32>)-> vector<2xf32> {
+//       CHECK-PROP:     vector.yield %[[EXTRACTELT]] : index
+//       CHECK-PROP:     %[[APPLY:.*]] = affine.apply #[[$MAP]]()[%[[THREADID]]]
+//       CHECK-PROP:   %[[TRANSFERREAD:.*]] = vector.transfer_read %[[AR2]][%[[C0]], %[[W]], %[[APPLY]]],
+//       CHECK-PROP:   return %[[TRANSFERREAD]]
+func.func @transfer_read_prop_operands(%in2: vector<1x2xindex>, %ar1 :  memref<1x4x2xi32>, %ar2 : memref<1x4x1024xf32>)-> vector<2xf32> {
   %0 = gpu.thread_id  x
   %c0_i32 = arith.constant 0 : i32
   %c0 = arith.constant 0 : index


