[Mlir-commits] [mlir] [mlir][vector] Root the transfer write distribution pattern on the warp op (PR #71868)

Quinn Dawkins llvmlistbot at llvm.org
Thu Nov 9 13:07:36 PST 2023


https://github.com/qedawkins created https://github.com/llvm/llvm-project/pull/71868

Currently, when a mix of transfer read ops and transfer write ops needs to be distributed, it is hard to guarantee that a write gets distributed after a read when the two aren't directly connected by SSA, because the pattern for write distribution is rooted on the transfer write. This patch roots the write distribution pattern on the warp op instead, so that pattern benefits can be used to order it relative to the read propagation pattern. This is likely still relatively unsafe when there are undistributable ops, but structurally these patterns are a bit difficult to work with. For now, pattern benefits give fairly good guarantees for happy paths.

From d537cf66c843661d0c6971bbab943a3c33036d91 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Thu, 9 Nov 2023 11:37:46 -0500
Subject: [PATCH] [mlir][vector] Root the transfer write distribution pattern
 on the warp op

Currently when there is a mix of transfer read ops and transfer write
ops that need to be distributed, because the pattern for write
distribution is rooted on the transfer write, it is hard to guarantee
that the write gets distributed after the read when the two aren't
directly connected by SSA.
---
 .../Vector/Transforms/VectorDistribution.h    | 12 +++++--
 .../Vector/Transforms/VectorDistribute.cpp    | 31 +++++++++----------
 .../Vector/vector-warp-distribute.mlir        | 21 +++++++++++++
 .../Dialect/Vector/TestVectorTransforms.cpp   | 13 ++++++--
 4 files changed, 56 insertions(+), 21 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index a76a58eb5ec6d3c..d037ed1d142b1d9 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -61,7 +61,7 @@ using DistributionMapFn = std::function<AffineMap(Value)>;
 /// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
 void populateDistributeTransferWriteOpPatterns(
     RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
-    PatternBenefit benefit = 1);
+    PatternBenefit benefit = 2);
 
 /// Move scalar operations with no dependency on the warp op outside of the
 /// region.
@@ -75,10 +75,18 @@ using WarpShuffleFromIdxFn =
 /// Collect patterns to propagate warp distribution. `distributionMapFn` is used
 /// to decide how a value should be distributed when this cannot be inferred
 /// from its uses.
+///
+/// Added control over the pattern benefit for propagating
+/// `vector.transfer_read` ops is given to ensure the order of reads/writes
+/// before and after distribution is consistent. Writes are expected to have
+/// the highest priority for distribution, but are only ever distributed if
+/// they are adjacent to the yield. By making reads the lowest priority pattern,
+/// they will be the last pure vector operations to distribute, meaning writes
+/// should propagate first.
 void populatePropagateWarpVectorDistributionPatterns(
     RewritePatternSet &pattern, const DistributionMapFn &distributionMapFn,
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
-    PatternBenefit benefit = 1);
+    PatternBenefit benefit = 1, PatternBenefit readBenefit = 0);
 
 /// Lambda signature to compute a reduction of a distributed value for the given
 /// reduction kind and size.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index ac2a23221ad5093..334d23e08419cea 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -474,10 +474,10 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
 ///   vector.yield %v : vector<32xf32>
 /// }
 /// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
-struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
+struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
   WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                       PatternBenefit b = 1)
-      : OpRewritePattern<vector::TransferWriteOp>(ctx, b),
+      : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
         distributionMapFn(std::move(fn)) {}
 
   /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
@@ -584,18 +584,15 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
     return success();
   }
 
-  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
-    if (!warpOp)
+    auto yield = cast<vector::YieldOp>(
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    Operation *lastNode = yield->getPrevNode();
+    auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
+    if (!writeOp)
       return failure();
 
-    // There must be no op with a side effect after writeOp.
-    Operation *nextOp = writeOp.getOperation();
-    while ((nextOp = nextOp->getNextNode()))
-      if (!isMemoryEffectFree(nextOp))
-        return failure();
-
     Value maybeMask = writeOp.getMask();
     if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
           return writeOp.getVector() == value ||
@@ -1731,11 +1728,13 @@ void mlir::vector::populateDistributeTransferWriteOpPatterns(
 
 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
-    const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit) {
-  patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
-               WarpOpBroadcast, WarpOpShapeCast, WarpOpExtract,
-               WarpOpForwardOperand, WarpOpConstant, WarpOpInsertElement,
-               WarpOpInsert>(patterns.getContext(), benefit);
+    const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
+    PatternBenefit readBenefit) {
+  patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
+  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
+               WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
+               WarpOpConstant, WarpOpInsertElement, WarpOpInsert>(
+      patterns.getContext(), benefit);
   patterns.add<WarpOpExtractElement>(patterns.getContext(),
                                      warpShuffleFromIdxFn, benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 41b3d5d97728c5b..6d8ad5a0e88c2bd 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1348,3 +1348,24 @@ func.func @warp_propagate_masked_transfer_read(%laneid: index, %src: memref<4096
 //       CHECK-PROP:   vector.transfer_read {{.*}}[%[[DIST_READ_IDX0]], %[[ARG2]]], {{.*}}, %[[R]]#1 {{.*}} vector<2x2xf32>
 //       CHECK-PROP:   %[[DIST_READ_IDX1:.+]] = affine.apply #[[$MAP1]]()[%[[ARG2]], %[[ARG0]]]
 //       CHECK-PROP:   vector.transfer_read {{.*}}[%[[C0]], %[[DIST_READ_IDX1]]], {{.*}}, %[[R]]#0 {{.*}} vector<2xf32>
+
+// -----
+
+func.func @warp_propagate_unconnected_read_write(%laneid: index, %buffer: memref<128xf32>, %f1: f32) -> (vector<2xf32>, vector<4xf32>) {
+  %f0 = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %r:2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>, vector<4xf32>) {
+    %cst = arith.constant dense<2.0> : vector<128xf32>
+    %0 = vector.transfer_read %buffer[%c0], %f0 {in_bounds = [true]} : memref<128xf32>, vector<128xf32>
+    vector.transfer_write %cst, %buffer[%c0] : vector<128xf32>, memref<128xf32>
+    %1 = vector.broadcast %f1 : f32 to vector<64xf32>
+    vector.yield %1, %0 : vector<64xf32>, vector<128xf32>
+  }
+  return %r#0, %r#1 : vector<2xf32>, vector<4xf32>
+}
+
+// Verify that the write comes after the read
+// CHECK-DIST-AND-PROP-LABEL: func.func @warp_propagate_unconnected_read_write(
+//       CHECK-DIST-AND-PROP:   %[[CST:.+]] = arith.constant dense<2.000000e+00> : vector<4xf32>
+//       CHECK-DIST-AND-PROP:   vector.transfer_read {{.*}} : memref<128xf32>, vector<4xf32>
+//       CHECK-DIST-AND-PROP:   vector.transfer_write %[[CST]], {{.*}} : vector<4xf32>, memref<128xf32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 2fbf1babf437f08..1a177fa31de37ce 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -594,12 +594,19 @@ struct TestVectorDistribution
                          .getResult(0);
       return result;
     };
-    if (distributeTransferWriteOps) {
+    if (distributeTransferWriteOps && propagateDistribution) {
+      RewritePatternSet patterns(ctx);
+      vector::populatePropagateWarpVectorDistributionPatterns(
+          patterns, distributionFn, shuffleFn, /*benefit=*/1,
+          /*readBenefit=*/0);
+      vector::populateDistributeReduction(patterns, warpReduction, 1);
+      populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 2);
+      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+    } else if (distributeTransferWriteOps) {
       RewritePatternSet patterns(ctx);
       populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
-    }
-    if (propagateDistribution) {
+    } else if (propagateDistribution) {
       RewritePatternSet patterns(ctx);
       vector::populatePropagateWarpVectorDistributionPatterns(
           patterns, distributionFn, shuffleFn);



More information about the Mlir-commits mailing list