[Mlir-commits] [mlir] [mlir][vector] Allow vector distribution with multiple written elements (PR #75122)

Jakub Kuderski llvmlistbot at llvm.org
Mon Dec 11 21:51:05 PST 2023


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/75122

>From 6da2cf6a80b4f95370ab0212e4a5c9047f283287 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Mon, 11 Dec 2023 19:55:26 -0500
Subject: [PATCH 1/2] [mlir][vector] Allow vector distribution with multiple
 written elements

Add a configuration option to allow vector distribution with multiple
elements written by a single lane.

This is so that we can perform vector multi-reduction with multiple
results per workgroup.
---
 .../Vector/Transforms/VectorDistribution.h    |   6 +-
 .../Vector/Transforms/VectorDistribute.cpp    |  26 +++--
 .../Vector/vector-warp-distribute.mlir        | 100 +++++++++++++++++-
 .../Dialect/Vector/TestVectorTransforms.cpp   |   8 +-
 4 files changed, 123 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index f32efd94691a4..c7238fcbbb702 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -43,7 +43,9 @@ void populateWarpExecuteOnLane0OpToScfForPattern(
 using DistributionMapFn = std::function<AffineMap(Value)>;
 
 /// Distribute transfer_write ops based on the affine map returned by
-/// `distributionMapFn`.
+/// `distributionMapFn`. Writes of size more than `maxNumElementToExtract`
+/// will not be distributed.
+///
 /// Example:
 /// ```
 /// %0 = vector.warp_execute_on_lane_0(%id){
@@ -67,7 +69,7 @@ using DistributionMapFn = std::function<AffineMap(Value)>;
 /// distribute, meaning writes should propagate first.
 void populateDistributeTransferWriteOpPatterns(
     RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
-    PatternBenefit benefit = 2);
+    unsigned maxNumElementsToExtract, PatternBenefit benefit = 2);
 
 /// Move scalar operations with no dependency on the warp op outside of the
 /// region.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 07ecd88575203..1d35900a2c5db 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/Support/FormatVariadic.h"
 #include <numeric>
 #include <utility>
 
@@ -458,7 +459,9 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
 }
 
 /// Distribute transfer_write ops based on the affine map returned by
-/// `distributionMapFn`.
+/// `distributionMapFn`. Writes of size more than `maxNumElementToExtract`
+/// will not be distributed.
+///
 /// Example:
 /// ```
 /// %0 = vector.warp_execute_on_lane_0(%id){
@@ -476,9 +479,10 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
 /// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
 struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
   WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
-                      PatternBenefit b = 1)
+                      unsigned maxNumElementsToExtract, PatternBenefit b = 1)
       : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
-        distributionMapFn(std::move(fn)) {}
+        distributionMapFn(std::move(fn)),
+        maxNumElementsToExtract(maxNumElementsToExtract) {}
 
   /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
   /// are multiples of the distribution ratio are supported at the moment.
@@ -553,10 +557,13 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
     Location loc = writeOp.getLoc();
     VectorType vecType = writeOp.getVectorType();
 
-    // Only sink out vector of 1 element for now to not serialize large vector
-    // store. This can later be controlled by user.
-    if (vecType.getNumElements() != 1)
-      return failure();
+    if (vecType.getNumElements() > maxNumElementsToExtract) {
+      return rewriter.notifyMatchFailure(
+          warpOp,
+          llvm::formatv(
+              "writes more elements ({0}) than allowed to extract ({1})",
+              vecType.getNumElements(), maxNumElementsToExtract));
+    }
 
     // Do not process warp ops that contain only TransferWriteOps.
     if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
@@ -616,6 +623,7 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
 
 private:
   DistributionMapFn distributionMapFn;
+  unsigned maxNumElementsToExtract = 1;
 };
 
 /// Sink out elementwise op feeding into a warp op yield.
@@ -1833,9 +1841,9 @@ void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
 
 void mlir::vector::populateDistributeTransferWriteOpPatterns(
     RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
-    PatternBenefit benefit) {
+    unsigned maxNumElementsToExtract, PatternBenefit benefit) {
   patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
-                                    benefit);
+                                    maxNumElementsToExtract, benefit);
 }
 
 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index ab175effa3dfb..e04fa64f0f8a7 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1,8 +1,20 @@
-// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if | FileCheck %s --check-prefix=CHECK-SCF-IF
-// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform" | FileCheck --check-prefixes=CHECK-HOIST %s
-// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform distribute-transfer-write" | FileCheck --check-prefixes=CHECK-D %s
-// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=propagate-distribution -canonicalize | FileCheck --check-prefixes=CHECK-PROP %s
-// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform distribute-transfer-write propagate-distribution" -canonicalize | FileCheck --check-prefixes=CHECK-DIST-AND-PROP %s
+// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
+// RUN:   --test-vector-warp-distribute=rewrite-warp-ops-to-scf-if | FileCheck %s --check-prefix=CHECK-SCF-IF
+
+// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
+// RUN:   --test-vector-warp-distribute="hoist-uniform" | FileCheck --check-prefixes=CHECK-HOIST %s
+
+// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
+// RUN:   --test-vector-warp-distribute="hoist-uniform distribute-transfer-write max-transfer-write-elements=4" \
+// RUN:   | FileCheck --check-prefixes=CHECK-D %s
+
+// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
+// RUN:  --test-vector-warp-distribute=propagate-distribution --canonicalize \
+// RUN:  | FileCheck --check-prefixes=CHECK-PROP %s
+
+// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
+// RUN:   --test-vector-warp-distribute="hoist-uniform distribute-transfer-write propagate-distribution" \
+// RUN:   --canonicalize | FileCheck --check-prefixes=CHECK-DIST-AND-PROP %s
 
 // CHECK-SCF-IF-DAG: #[[$TIMES2:.*]] = affine_map<()[s0] -> (s0 * 2)>
 // CHECK-SCF-IF-DAG: #[[$TIMES4:.*]] = affine_map<()[s0] -> (s0 * 4)>
@@ -134,6 +146,84 @@ func.func @warp_extract(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : ind
 
 // -----
 
+// Check that we can distribute writes of the maximum allowed number of elements.
+
+// CHECK-D-LABEL: func @warp_extract_4_elems(
+//       CHECK-D:   %[[WARPOP:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4x1xf32>)
+//       CHECK-D:     "test.dummy_op"
+//       CHECK-D:     "test.dummy_op"
+//       CHECK-D:     vector.yield %{{.*}}, %{{.*}} : vector<4xf32>, vector<4x1xf32>
+//       CHECK-D:   }
+//       CHECK-D:   vector.warp_execute_on_lane_0(%{{.*}})[32] {
+//       CHECK-D:     vector.transfer_write %[[WARPOP]]#1, %{{.*}}[%{{.*}}] {{.*}} : vector<4x1xf32>
+//       CHECK-D:   }
+//       CHECK-D:   vector.warp_execute_on_lane_0(%{{.*}})[32] {
+//       CHECK-D:     vector.transfer_write %[[WARPOP]]#0, %{{.*}}[%{{.*}}] {{.*}} : vector<4xf32>
+//       CHECK-D:   }
+
+func.func @warp_extract_4_elems(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : index) {
+  vector.warp_execute_on_lane_0(%laneid)[32] {
+    %c0 = arith.constant 0 : index
+    %v = "test.dummy_op"() : () -> (vector<4xf32>)
+    %v1 = "test.dummy_op"() : () -> (vector<4x1xf32>)
+    vector.transfer_write %v1, %arg1[%c0, %c0] : vector<4x1xf32>, memref<1024x1024xf32>
+    vector.transfer_write %v, %arg1[%c0, %c0] : vector<4xf32>, memref<1024x1024xf32>
+  }
+  return
+}
+
+// -----
+
+// Check that we do not distribute writes larger than the maximum allowed
+// number of elements.
+
+// CHECK-D-LABEL: func @warp_extract_5_elems(
+//       CHECK-D:   arith.constant 0 : index
+//       CHECK-D:   vector.warp_execute_on_lane_0(%{{.*}})[32] {
+//       CHECK-D:     %[[V:.+]] = "test.dummy_op"
+//       CHECK-D:     %[[V1:.+]] = "test.dummy_op"
+//       CHECK-D:     vector.transfer_write %[[V1]], %{{.*}}[%{{.*}}] {{.*}} : vector<5x1xf32>
+//       CHECK-D:     vector.transfer_write %[[V]], %{{.*}}[%{{.*}}] {{.*}} : vector<5xf32>
+//       CHECK-D:   }
+
+func.func @warp_extract_5_elems(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : index) {
+  vector.warp_execute_on_lane_0(%laneid)[32] {
+    %c0 = arith.constant 0 : index
+    %v = "test.dummy_op"() : () -> (vector<5xf32>)
+    %v1 = "test.dummy_op"() : () -> (vector<5x1xf32>)
+    vector.transfer_write %v1, %arg1[%c0, %c0] : vector<5x1xf32>, memref<1024x1024xf32>
+    vector.transfer_write %v, %arg1[%c0, %c0] : vector<5xf32>, memref<1024x1024xf32>
+  }
+  return
+}
+
+// -----
+
+// Check that we do not distribute writes larger than the maximum allowed
+// number of elements, or multiples of the maximum number of elements.
+
+// CHECK-D-LABEL: func @warp_extract_8_elems(
+//       CHECK-D:   arith.constant 0 : index
+//       CHECK-D:   vector.warp_execute_on_lane_0(%{{.*}})[32] {
+//       CHECK-D:     %[[V:.+]] = "test.dummy_op"
+//       CHECK-D:     %[[V1:.+]] = "test.dummy_op"
+//       CHECK-D:     vector.transfer_write %[[V1]], %{{.*}}[%{{.*}}] {{.*}} : vector<8x1xf32>
+//       CHECK-D:     vector.transfer_write %[[V]], %{{.*}}[%{{.*}}] {{.*}} : vector<8xf32>
+//       CHECK-D:   }
+
+func.func @warp_extract_8_elems(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : index) {
+  vector.warp_execute_on_lane_0(%laneid)[32] {
+    %c0 = arith.constant 0 : index
+    %v = "test.dummy_op"() : () -> (vector<8xf32>)
+    %v1 = "test.dummy_op"() : () -> (vector<8x1xf32>)
+    vector.transfer_write %v1, %arg1[%c0, %c0] : vector<8x1xf32>, memref<1024x1024xf32>
+    vector.transfer_write %v, %arg1[%c0, %c0] : vector<8xf32>, memref<1024x1024xf32>
+  }
+  return
+}
+
+// -----
+
 // CHECK-PROP-LABEL:   func @warp_dead_result(
 func.func @warp_dead_result(%laneid: index) -> (vector<1xf32>) {
   // CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>)
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 86b8d5f9b0995..e593c0defcd29 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -568,6 +568,11 @@ struct TestVectorDistribution
       llvm::cl::desc("Test distribution of transfer write"),
       llvm::cl::init(false)};
 
+  Option<unsigned> maxTransferWriteElements{
+      *this, "max-transfer-write-elements",
+      llvm::cl::desc("Maximum number of transfer write elements to distribute"),
+      llvm::cl::init(1)};
+
   Option<bool> hoistUniform{*this, "hoist-uniform",
                             llvm::cl::desc("Test hoist uniform"),
                             llvm::cl::init(false)};
@@ -624,7 +629,8 @@ struct TestVectorDistribution
       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
     } else if (distributeTransferWriteOps) {
       RewritePatternSet patterns(ctx);
-      populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
+      populateDistributeTransferWriteOpPatterns(patterns, distributionFn,
+                                                maxTransferWriteElements);
       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
     } else if (propagateDistribution) {
       RewritePatternSet patterns(ctx);

>From 4cd462684b38747272030d1e11974f086540452a Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Tue, 12 Dec 2023 00:50:55 -0500
Subject: [PATCH 2/2] Clarify max num elements to extract

---
 .../include/mlir/Dialect/Vector/Transforms/VectorDistribution.h | 2 +-
 mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp         | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index c7238fcbbb702..8907a2a583609 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -44,7 +44,7 @@ using DistributionMapFn = std::function<AffineMap(Value)>;
 
 /// Distribute transfer_write ops based on the affine map returned by
 /// `distributionMapFn`. Writes of size more than `maxNumElementToExtract`
-/// will not be distributed.
+/// will not be distributed (it should be less than the warp size).
 ///
 /// Example:
 /// ```
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 1d35900a2c5db..074356ab42537 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -460,7 +460,7 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
 
 /// Distribute transfer_write ops based on the affine map returned by
 /// `distributionMapFn`. Writes of size more than `maxNumElementToExtract`
-/// will not be distributed.
+/// will not be distributed (it should be less than the warp size).
 ///
 /// Example:
 /// ```



More information about the Mlir-commits mailing list