[Mlir-commits] [mlir] [mlir][vector] Fix cases with multiple yielded transfer_read ops (PR #71625)

Quinn Dawkins llvmlistbot at llvm.org
Thu Nov 9 06:08:58 PST 2023


https://github.com/qedawkins updated https://github.com/llvm/llvm-project/pull/71625

>From fcd6163ec85dc3526d8bfe87c6a7b408a7565a1c Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Mon, 6 Nov 2023 18:42:43 -0500
Subject: [PATCH 1/2] [mlir][vector] Fix cases with multiple yielded
 transfer_read ops

This fixes two bugs:
1) When deciding whether a transfer read could be propagated out of
   a warp op, the pattern only looked at the first yield operand that
   was produced by a transfer read. If that transfer read wasn't ready
   to be distributed, the pattern never checked whether any other
   yielded transfer reads could have been propagated instead.
2) When dropping dead warp results, we do so by updating the warp op
   signature and splicing in the old region. This does not add the ops
   in the body of the warp op back to the pattern applicator's worklist,
   and thus those operations won't be DCE'd. This is a problem for
   patterns such as the transfer_read propagation pattern, which will
   still see the dead operation as a user of the read and therefore
   refuse to distribute it.
---
 .../Vector/Transforms/VectorDistribute.cpp    | 40 +++++++++++++++----
 .../Vector/vector-warp-distribute.mlir        | 37 +++++++++++++++++
 2 files changed, 70 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e128cc71a5d628c..f67e03510ba6ca6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -801,13 +801,31 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand = getWarpResult(
-        warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
-    if (!operand)
-      return failure();
-    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
-    // Don't duplicate transfer_read ops when distributing.
-    if (!read.getResult().hasOneUse())
+    // Try to find a distributable yielded read. Note that this pattern can
+    // still fail at the end after distribution, in which case this might have
+    // missed another distributable read.
+    vector::TransferReadOp read;
+    auto yield = cast<vector::YieldOp>(
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    OpOperand *operand;
+    for (OpOperand &yieldOperand : yield->getOpOperands()) {
+      Value yieldValues = yieldOperand.get();
+      Operation *definedOp = yieldValues.getDefiningOp();
+      if (!definedOp)
+        continue;
+      auto maybeRead = dyn_cast<vector::TransferReadOp>(definedOp);
+      if (!maybeRead)
+        continue;
+      if (warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
+        continue;
+      // Don't duplicate transfer_read ops when distributing.
+      if (!maybeRead.getResult().hasOneUse())
+        continue;
+      read = maybeRead;
+      operand = &yieldOperand;
+      break;
+    }
+    if (!read)
       return failure();
     unsigned operandIndex = operand->getOperandNumber();
     Value distributedVal = warpOp.getResult(operandIndex);
@@ -913,6 +931,14 @@ struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
     // Move the body of the old warpOp to a new warpOp.
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
         rewriter, warpOp, newYieldValues, newResultTypes);
+
+    // Simplify the new warp op after dropping dead results.
+    auto simplifyFn = [&](Operation *op) {
+      if (isOpTriviallyDead(op))
+        rewriter.eraseOp(op);
+    };
+    newWarpOp.getBody()->walk(simplifyFn);
+
     // Replace results of the old warpOp by the new, deduplicated results.
     SmallVector<Value> newValues;
     newValues.reserve(warpOp->getNumResults());
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index f050bcd246e5ef7..3f95a39100b2b88 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1256,6 +1256,43 @@ func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<409
 
 // -----
 
+func.func @warp_propagate_multi_transfer_read(%laneid: index, %src: memref<4096xf32>, %index: index, %index1: index) -> (vector<1xf32>, vector<1xf32>) {
+  %f0 = arith.constant 0.000000e+00 : f32
+  %r:2 = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<1xf32>, vector<1xf32>) {
+    %0 = vector.transfer_read %src[%index], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
+    "some_use"(%0) : (vector<1xf32>) -> ()
+    %1 = vector.transfer_read %src[%index1], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
+    vector.yield %0, %1 : vector<1xf32>, vector<1xf32>
+  }
+  return %r#0, %r#1 : vector<1xf32>, vector<1xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_multi_transfer_read
+//       CHECK-PROP:   vector.warp_execute_on_lane_0{{.*}} -> (vector<1xf32>)
+//       CHECK-PROP:     %[[INNER_READ:.+]] = vector.transfer_read
+//       CHECK-PROP:     "some_use"(%[[INNER_READ]])
+//       CHECK-PROP:     vector.yield %[[INNER_READ]] : vector<1xf32>
+//       CHECK-PROP:   vector.transfer_read
+
+// -----
+
+func.func @warp_propagate_dead_user_multi_read(%laneid: index, %src: memref<4096xf32>, %index: index, %index1: index) -> (vector<1xf32>) {
+  %f0 = arith.constant 0.000000e+00 : f32
+  %r = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<1xf32>) {
+    %0 = vector.transfer_read %src[%index], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<64xf32>
+    %1 = vector.transfer_read %src[%index1], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<64xf32>
+    %max = arith.maximumf %0, %1 : vector<64xf32>
+    vector.yield %max : vector<64xf32>
+  }
+  return %r : vector<1xf32>
+}
+
+//   CHECK-PROP-LABEL: func.func @warp_propagate_dead_user_multi_read
+// CHECK-PROP-COUNT-2:   vector.transfer_read {{.*}} vector<1xf32>
+//         CHECK-PROP:   arith.maximumf {{.*}} : vector<1xf32>
+
+// -----
+
 func.func @warp_propagate_masked_write(%laneid: index, %dest: memref<4096xf32>) {
   %c0 = arith.constant 0 : index
   vector.warp_execute_on_lane_0(%laneid)[32] -> () {

>From bfa040d20b0b4f44712ac0e45060b090949dc223 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Thu, 9 Nov 2023 09:07:17 -0500
Subject: [PATCH 2/2] Cleanup logic getting transfer read warp results

---
 .../Vector/Transforms/VectorDistribute.cpp    | 32 +++++--------------
 1 file changed, 8 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index f67e03510ba6ca6..b6d9c804a3c0fb4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -804,29 +804,14 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
     // Try to find a distributable yielded read. Note that this pattern can
     // still fail at the end after distribution, in which case this might have
     // missed another distributable read.
-    vector::TransferReadOp read;
-    auto yield = cast<vector::YieldOp>(
-        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-    OpOperand *operand;
-    for (OpOperand &yieldOperand : yield->getOpOperands()) {
-      Value yieldValues = yieldOperand.get();
-      Operation *definedOp = yieldValues.getDefiningOp();
-      if (!definedOp)
-        continue;
-      auto maybeRead = dyn_cast<vector::TransferReadOp>(definedOp);
-      if (!maybeRead)
-        continue;
-      if (warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
-        continue;
+    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
       // Don't duplicate transfer_read ops when distributing.
-      if (!maybeRead.getResult().hasOneUse())
-        continue;
-      read = maybeRead;
-      operand = &yieldOperand;
-      break;
-    }
-    if (!read)
+      return isa<vector::TransferReadOp>(op) && op->hasOneUse();
+    });
+    if (!operand)
       return failure();
+    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
+
     unsigned operandIndex = operand->getOperandNumber();
     Value distributedVal = warpOp.getResult(operandIndex);
 
@@ -933,11 +918,10 @@ struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
         rewriter, warpOp, newYieldValues, newResultTypes);
 
     // Simplify the new warp op after dropping dead results.
-    auto simplifyFn = [&](Operation *op) {
+    newWarpOp.getBody()->walk([&](Operation *op) {
       if (isOpTriviallyDead(op))
         rewriter.eraseOp(op);
-    };
-    newWarpOp.getBody()->walk(simplifyFn);
+    });
 
     // Replace results of the old warpOp by the new, deduplicated results.
     SmallVector<Value> newValues;



More information about the Mlir-commits mailing list