[Mlir-commits] [mlir] 7360d5d - [mlir][vector] Fix cases with multiple yielded transfer_read ops (#71625)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 9 08:35:59 PST 2023


Author: Quinn Dawkins
Date: 2023-11-09T11:35:54-05:00
New Revision: 7360d5d30fe75ee86696d1e9a8a62cb74b47254a

URL: https://github.com/llvm/llvm-project/commit/7360d5d30fe75ee86696d1e9a8a62cb74b47254a
DIFF: https://github.com/llvm/llvm-project/commit/7360d5d30fe75ee86696d1e9a8a62cb74b47254a.diff

LOG: [mlir][vector] Fix cases with multiple yielded transfer_read ops (#71625)

This fixes two bugs:
1) When deciding whether a transfer read could be propagated out of
   a warp op, the pattern only considered the first yield operand that
   was produced by a transfer read. If that transfer read wasn't ready
   to be distributed (for example, because it had more than one use),
   the pattern failed without checking any other yielded transfer reads
   that could have been propagated.
2) When dropping dead warp results, the pattern updates the warp op
   signature and splices in the old region. This does not add the ops
   in the body of the warp op back to the pattern applicator's worklist,
   so those operations are not DCE'd. This is a problem for patterns,
   such as the one for transfer reads, that still see the dead
   operation as a user.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 1975ba9c92d9988..ac2a23221ad5093 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -801,14 +801,17 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand = getWarpResult(
-        warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
+    // Try to find a distributable yielded read. Note that this pattern can
+    // still fail at the end after distribution, in which case this might have
+    // missed another distributable read.
+    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
+      // Don't duplicate transfer_read ops when distributing.
+      return isa<vector::TransferReadOp>(op) && op->hasOneUse();
+    });
     if (!operand)
       return failure();
     auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
-    // Don't duplicate transfer_read ops when distributing.
-    if (!read.getResult().hasOneUse())
-      return failure();
+
     unsigned operandIndex = operand->getOperandNumber();
     Value distributedVal = warpOp.getResult(operandIndex);
 
@@ -937,6 +940,13 @@ struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
     // Move the body of the old warpOp to a new warpOp.
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
         rewriter, warpOp, newYieldValues, newResultTypes);
+
+    // Simplify the new warp op after dropping dead results.
+    newWarpOp.getBody()->walk([&](Operation *op) {
+      if (isOpTriviallyDead(op))
+        rewriter.eraseOp(op);
+    });
+
     // Replace results of the old warpOp by the new, deduplicated results.
     SmallVector<Value> newValues;
     newValues.reserve(warpOp->getNumResults());

diff  --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 2a1007fbbe86435..41b3d5d97728c5b 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1256,6 +1256,43 @@ func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<409
 
 // -----
 
+func.func @warp_propagate_multi_transfer_read(%laneid: index, %src: memref<4096xf32>, %index: index, %index1: index) -> (vector<1xf32>, vector<1xf32>) {
+  %f0 = arith.constant 0.000000e+00 : f32
+  %r:2 = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<1xf32>, vector<1xf32>) {
+    %0 = vector.transfer_read %src[%index], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
+    "some_use"(%0) : (vector<1xf32>) -> ()
+    %1 = vector.transfer_read %src[%index1], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
+    vector.yield %0, %1 : vector<1xf32>, vector<1xf32>
+  }
+  return %r#0, %r#1 : vector<1xf32>, vector<1xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_multi_transfer_read
+//       CHECK-PROP:   vector.warp_execute_on_lane_0{{.*}} -> (vector<1xf32>)
+//       CHECK-PROP:     %[[INNER_READ:.+]] = vector.transfer_read
+//       CHECK-PROP:     "some_use"(%[[INNER_READ]])
+//       CHECK-PROP:     vector.yield %[[INNER_READ]] : vector<1xf32>
+//       CHECK-PROP:   vector.transfer_read
+
+// -----
+
+func.func @warp_propagate_dead_user_multi_read(%laneid: index, %src: memref<4096xf32>, %index: index, %index1: index) -> (vector<1xf32>) {
+  %f0 = arith.constant 0.000000e+00 : f32
+  %r = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<1xf32>) {
+    %0 = vector.transfer_read %src[%index], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<64xf32>
+    %1 = vector.transfer_read %src[%index1], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<64xf32>
+    %max = arith.maximumf %0, %1 : vector<64xf32>
+    vector.yield %max : vector<64xf32>
+  }
+  return %r : vector<1xf32>
+}
+
+//   CHECK-PROP-LABEL: func.func @warp_propagate_dead_user_multi_read
+// CHECK-PROP-COUNT-2:   vector.transfer_read {{.*}} vector<1xf32>
+//         CHECK-PROP:   arith.maximumf {{.*}} : vector<1xf32>
+
+// -----
+
 func.func @warp_propagate_masked_write(%laneid: index, %dest: memref<4096xf32>) {
   %c0 = arith.constant 0 : index
   vector.warp_execute_on_lane_0(%laneid)[32] -> () {


        


More information about the Mlir-commits mailing list