[Mlir-commits] [mlir] df49a97 - [mlir][vector] Root the transfer write distribution pattern on the warp op (#71868)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 10 05:49:36 PST 2023


Author: Quinn Dawkins
Date: 2023-11-10T08:49:33-05:00
New Revision: df49a97ab2e952d1588d9e7784987d982ebf3365

URL: https://github.com/llvm/llvm-project/commit/df49a97ab2e952d1588d9e7784987d982ebf3365
DIFF: https://github.com/llvm/llvm-project/commit/df49a97ab2e952d1588d9e7784987d982ebf3365.diff

LOG: [mlir][vector] Root the transfer write distribution pattern on the warp op (#71868)

Currently, when a mix of transfer read ops and transfer write ops needs
to be distributed, the pattern for write distribution is rooted on the
transfer write. This makes it hard to guarantee that the distributed
write still ends up after the distributed read when the two aren't
directly connected by SSA. This change roots the pattern on the warp op
instead and only distributes a write that immediately precedes the
yield, so that pattern benefits can be used to distribute writes before
reads. This is likely still relatively unsafe when there are
undistributable ops, but structurally these patterns are a bit difficult
to work with. For now, pattern benefits give fairly good guarantees for
happy paths.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index a76a58eb5ec6d3c..f32efd94691a410 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -59,9 +59,15 @@ using DistributionMapFn = std::function<AffineMap(Value)>;
 ///   vector.yield %v : vector<32xf32>
 /// }
 /// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
+///
+/// When applied at the same time as the vector propagation patterns,
+/// distribution of `vector.transfer_write` is expected to have the highest
+/// priority (pattern benefit). By making propagation of `vector.transfer_read`
+/// be the lowest priority pattern, it will be the last vector operation to
+/// distribute, meaning writes should propagate first.
 void populateDistributeTransferWriteOpPatterns(
     RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
-    PatternBenefit benefit = 1);
+    PatternBenefit benefit = 2);
 
 /// Move scalar operations with no dependency on the warp op outside of the
 /// region.
@@ -75,10 +81,19 @@ using WarpShuffleFromIdxFn =
 /// Collect patterns to propagate warp distribution. `distributionMapFn` is used
 /// to decide how a value should be distributed when this cannot be inferred
 /// from its uses.
+///
+/// The separate control over the `vector.transfer_read` op pattern benefit
+/// is given to ensure the order of reads/writes before and after distribution
+/// is consistent. As noted above, writes are expected to have the highest
+/// priority for distribution, but are only ever distributed if adjacent to the
+/// yield. By making reads the lowest priority pattern, it will be the last
+/// vector operation to distribute, meaning writes should propagate first. This
+/// is relatively brittle when ops fail to distribute, but that is a limitation
+/// of these propagation patterns when there is a dependency not modeled by SSA.
 void populatePropagateWarpVectorDistributionPatterns(
     RewritePatternSet &pattern, const DistributionMapFn &distributionMapFn,
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
-    PatternBenefit benefit = 1);
+    PatternBenefit benefit = 1, PatternBenefit readBenefit = 0);
 
 /// Lambda signature to compute a reduction of a distributed value for the given
 /// reduction kind and size.

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index ac2a23221ad5093..334d23e08419cea 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -474,10 +474,10 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
 ///   vector.yield %v : vector<32xf32>
 /// }
 /// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
-struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
+struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
   WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                       PatternBenefit b = 1)
-      : OpRewritePattern<vector::TransferWriteOp>(ctx, b),
+      : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
         distributionMapFn(std::move(fn)) {}
 
   /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
@@ -584,18 +584,15 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
     return success();
   }
 
-  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
-    if (!warpOp)
+    auto yield = cast<vector::YieldOp>(
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    Operation *lastNode = yield->getPrevNode();
+    auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
+    if (!writeOp)
       return failure();
 
-    // There must be no op with a side effect after writeOp.
-    Operation *nextOp = writeOp.getOperation();
-    while ((nextOp = nextOp->getNextNode()))
-      if (!isMemoryEffectFree(nextOp))
-        return failure();
-
     Value maybeMask = writeOp.getMask();
     if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
           return writeOp.getVector() == value ||
@@ -1731,11 +1728,13 @@ void mlir::vector::populateDistributeTransferWriteOpPatterns(
 
 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
-    const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit) {
-  patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
-               WarpOpBroadcast, WarpOpShapeCast, WarpOpExtract,
-               WarpOpForwardOperand, WarpOpConstant, WarpOpInsertElement,
-               WarpOpInsert>(patterns.getContext(), benefit);
+    const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
+    PatternBenefit readBenefit) {
+  patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
+  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
+               WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
+               WarpOpConstant, WarpOpInsertElement, WarpOpInsert>(
+      patterns.getContext(), benefit);
   patterns.add<WarpOpExtractElement>(patterns.getContext(),
                                      warpShuffleFromIdxFn, benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,

diff  --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 41b3d5d97728c5b..6d8ad5a0e88c2bd 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1348,3 +1348,24 @@ func.func @warp_propagate_masked_transfer_read(%laneid: index, %src: memref<4096
 //       CHECK-PROP:   vector.transfer_read {{.*}}[%[[DIST_READ_IDX0]], %[[ARG2]]], {{.*}}, %[[R]]#1 {{.*}} vector<2x2xf32>
 //       CHECK-PROP:   %[[DIST_READ_IDX1:.+]] = affine.apply #[[$MAP1]]()[%[[ARG2]], %[[ARG0]]]
 //       CHECK-PROP:   vector.transfer_read {{.*}}[%[[C0]], %[[DIST_READ_IDX1]]], {{.*}}, %[[R]]#0 {{.*}} vector<2xf32>
+
+// -----
+
+func.func @warp_propagate_unconnected_read_write(%laneid: index, %buffer: memref<128xf32>, %f1: f32) -> (vector<2xf32>, vector<4xf32>) {
+  %f0 = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %r:2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>, vector<4xf32>) {
+    %cst = arith.constant dense<2.0> : vector<128xf32>
+    %0 = vector.transfer_read %buffer[%c0], %f0 {in_bounds = [true]} : memref<128xf32>, vector<128xf32>
+    vector.transfer_write %cst, %buffer[%c0] : vector<128xf32>, memref<128xf32>
+    %1 = vector.broadcast %f1 : f32 to vector<64xf32>
+    vector.yield %1, %0 : vector<64xf32>, vector<128xf32>
+  }
+  return %r#0, %r#1 : vector<2xf32>, vector<4xf32>
+}
+
+// Verify that the write comes after the read
+// CHECK-DIST-AND-PROP-LABEL: func.func @warp_propagate_unconnected_read_write(
+//       CHECK-DIST-AND-PROP:   %[[CST:.+]] = arith.constant dense<2.000000e+00> : vector<4xf32>
+//       CHECK-DIST-AND-PROP:   vector.transfer_read {{.*}} : memref<128xf32>, vector<4xf32>
+//       CHECK-DIST-AND-PROP:   vector.transfer_write %[[CST]], {{.*}} : vector<4xf32>, memref<128xf32>

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 2fbf1babf437f08..1a177fa31de37ce 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -594,12 +594,19 @@ struct TestVectorDistribution
                          .getResult(0);
       return result;
     };
-    if (distributeTransferWriteOps) {
+    if (distributeTransferWriteOps && propagateDistribution) {
+      RewritePatternSet patterns(ctx);
+      vector::populatePropagateWarpVectorDistributionPatterns(
+          patterns, distributionFn, shuffleFn, /*benefit=*/1,
+          /*readBenefit=*/0);
+      vector::populateDistributeReduction(patterns, warpReduction, 1);
+      populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 2);
+      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+    } else if (distributeTransferWriteOps) {
       RewritePatternSet patterns(ctx);
       populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
-    }
-    if (propagateDistribution) {
+    } else if (propagateDistribution) {
       RewritePatternSet patterns(ctx);
       vector::populatePropagateWarpVectorDistributionPatterns(
           patterns, distributionFn, shuffleFn);


        


More information about the Mlir-commits mailing list