[Mlir-commits] [mlir] aa2376a - [mlir][vector] Notify the rewriter when sinking out of warp ops (#71964)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 10 11:45:21 PST 2023


Author: Quinn Dawkins
Date: 2023-11-10T14:45:18-05:00
New Revision: aa2376a0835f8e2a92fc84fe172f67b16c298361

URL: https://github.com/llvm/llvm-project/commit/aa2376a0835f8e2a92fc84fe172f67b16c298361
DIFF: https://github.com/llvm/llvm-project/commit/aa2376a0835f8e2a92fc84fe172f67b16c298361.diff

LOG: [mlir][vector] Notify the rewriter when sinking out of warp ops (#71964)

A number of the warp distribution patterns work by rewriting a warp op
in place by moving a contained op outside. This notifies the rewriter
that the warp op is changing in this case.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 645caa9c1378821..ac73cf07004ded8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -645,6 +645,11 @@ struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
     });
     if (!yieldOperand)
       return failure();
+
+    // Notify the rewriter that the warp op is changing (see the comment on
+    // the WarpOpTransferRead pattern).
+    rewriter.startRootUpdate(warpOp);
+
     Operation *elementWise = yieldOperand->get().getDefiningOp();
     unsigned operandIndex = yieldOperand->getOperandNumber();
     Value distributedVal = warpOp.getResult(operandIndex);
@@ -683,6 +688,7 @@ struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
         {newWarpOp.getResult(operandIndex).getType()});
     rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
                                 newOp->getResult(0));
+    rewriter.finalizeRootUpdate(warpOp);
     return success();
   }
 };
@@ -713,6 +719,9 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
     auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
     if (!dense)
       return failure();
+    // Notify the rewriter that the warp op is changing (see the comment on
+    // the WarpOpTransferRead pattern).
+    rewriter.startRootUpdate(warpOp);
     unsigned operandIndex = yieldOperand->getOperandNumber();
     Attribute scalarAttr = dense.getSplatValue<Attribute>();
     auto newAttr = DenseElementsAttr::get(
@@ -721,6 +730,7 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
     rewriter.setInsertionPointAfter(warpOp);
     Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
     rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
+    rewriter.finalizeRootUpdate(warpOp);
     return success();
   }
 };
@@ -823,7 +833,9 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
     OpBuilder::InsertionGuard g(rewriter);
     WarpExecuteOnLane0Op newWarpOp = warpOp;
     Value newMask = read.getMask();
+    bool hasMask = false;
     if (read.getMask()) {
+      hasMask = true;
       // TODO: Distribution of masked reads with non-trivial permutation maps
       // requires the distribution of the mask to elementwise match the
       // distribution of the permuted written vector. Currently the details
@@ -840,6 +852,16 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
           newRetIndices);
       newMask = newWarpOp.getResult(newRetIndices[0]);
       distributedVal = newWarpOp.getResult(operandIndex);
+    } else {
+      // This pattern does not actually change the warp op directly. Instead it
+      // just rewrites a new transfer read (when not masked) outside of the warp
+      // op and replaces the corresponding result. There are then follow up
+      // patterns to erase now dead results of the warp op. This erasure allows
+      // propagation to continue, but this pattern on its own never actually
+      // tells the pattern rewriter that the warp op "changed." Notify the
+      // rewriter here that the warp op is changing. Similar situations are
+      // noted in following patterns.
+      rewriter.startRootUpdate(warpOp);
     }
 
     rewriter.setInsertionPointAfter(newWarpOp);
@@ -849,9 +871,12 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
     SmallVector<Value> delinearizedIds;
     if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
                            distributedType.getShape(), newWarpOp.getWarpSize(),
-                           newWarpOp.getLaneid(), delinearizedIds))
+                           newWarpOp.getLaneid(), delinearizedIds)) {
+      if (!hasMask)
+        rewriter.cancelRootUpdate(warpOp);
       return rewriter.notifyMatchFailure(
           read, "cannot delinearize lane ID for distribution");
+    }
     assert(!delinearizedIds.empty() || map.getNumResults() == 0);
 
     for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
@@ -890,10 +915,15 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
     if (!llvm::all_of(newRead->getOperands(), [&](Value value) {
           return (newRead.getMask() && value == newRead.getMask()) ||
                  newWarpOp.isDefinedOutsideOfRegion(value);
-        }))
+        })) {
+      if (!hasMask)
+        rewriter.cancelRootUpdate(warpOp);
       return failure();
+    }
 
     rewriter.replaceAllUsesWith(distributedVal, newRead);
+    if (!hasMask)
+      rewriter.finalizeRootUpdate(warpOp);
     return success();
   }
 };
@@ -996,7 +1026,11 @@ struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
     }
     if (!valForwarded)
       return failure();
+    // Notify the rewriter that the warp op is changing (see the comment on
+    // the WarpOpTransferRead pattern).
+    rewriter.startRootUpdate(warpOp);
     rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
+    rewriter.finalizeRootUpdate(warpOp);
     return success();
   }
 };
@@ -1024,6 +1058,9 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
     if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
         vector::BroadcastableToResult::Success)
       return failure();
+    // Notify the rewriter that the warp op is changing (see the comment on
+    // the WarpOpTransferRead pattern).
+    rewriter.startRootUpdate(warpOp);
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
@@ -1032,6 +1069,7 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
         loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
     rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                 broadcasted);
+    rewriter.finalizeRootUpdate(warpOp);
     return success();
   }
 };
@@ -1046,6 +1084,7 @@ struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
         warpOp, [](Operation *op) { return isa<vector::ShapeCastOp>(op); });
     if (!operand)
       return failure();
+
     auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
 
     unsigned int operandNumber = operand->getOperandNumber();
@@ -1133,6 +1172,10 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
           mask, "cannot delinearize lane ID for distribution");
     assert(!delinearizedIds.empty());
 
+    // Notify the rewriter that the warp op is changing (see the comment on
+    // the WarpOpTransferRead pattern).
+    rewriter.startRootUpdate(warpOp);
+
     AffineExpr s0, s1;
     bindSymbols(rewriter.getContext(), s0, s1);
     SmallVector<Value> newOperands;
@@ -1151,6 +1194,7 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
     auto newMask =
         rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
     rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
+    rewriter.finalizeRootUpdate(warpOp);
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 1821190c44e3af4..8056260f4610977 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1351,6 +1351,26 @@ func.func @warp_propagate_masked_transfer_read(%laneid: index, %src: memref<4096
 
 // -----
 
+func.func @warp_propagate_masked_transfer_read_shared_mask(%laneid: index, %src: memref<4096x4096xf32>, %index: index, %index2: index, %mask_ub: index) -> (vector<2xf32>, vector<2xf32>) {
+  %f0 = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %r:2 = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<2xf32>, vector<2xf32>) {
+    %mask = vector.create_mask %mask_ub: vector<128xi1>
+    %0 = vector.transfer_read %src[%c0, %index], %f0, %mask {in_bounds = [true]} : memref<4096x4096xf32>, vector<128xf32>
+    %1 = vector.transfer_read %src[%c0, %index2], %f0, %mask {in_bounds = [true]} : memref<4096x4096xf32>, vector<128xf32>
+    vector.yield %0, %1 : vector<128xf32>, vector<128xf32>
+  }
+  return %r#0, %r#1 : vector<2xf32>, vector<2xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_masked_transfer_read_shared_mask
+//       CHECK-PROP:   vector.create_mask %{{.*}} : vector<2xi1>
+//       CHECK-PROP:   vector.transfer_read %{{.*}} : memref<4096x4096xf32>, vector<2xf32>
+//       CHECK-PROP:   vector.create_mask %{{.*}} : vector<2xi1>
+//       CHECK-PROP:   vector.transfer_read %{{.*}} : memref<4096x4096xf32>, vector<2xf32>
+
+// -----
+
 func.func @warp_propagate_unconnected_read_write(%laneid: index, %buffer: memref<128xf32>, %f1: f32) -> (vector<2xf32>, vector<4xf32>) {
   %f0 = arith.constant 0.000000e+00 : f32
   %c0 = arith.constant 0 : index


        


More information about the Mlir-commits mailing list