[Mlir-commits] [mlir] [mlir][vector] Allow vector distribution with multiple written elements (PR #75122)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 11 16:58:40 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: Jakub Kuderski (kuhar)

<details>
<summary>Changes</summary>

Add a configuration option to allow vector distribution with multiple elements written by a single lane.

This is so that we can perform vector multi-reduction with multiple results per workgroup.

---
Full diff: https://github.com/llvm/llvm-project/pull/75122.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h (+4-2) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+17-9) 
- (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+95-5) 
- (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+7-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index f32efd94691a41..c7238fcbbb7023 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -43,7 +43,9 @@ void populateWarpExecuteOnLane0OpToScfForPattern(
 using DistributionMapFn = std::function<AffineMap(Value)>;
 
 /// Distribute transfer_write ops based on the affine map returned by
-/// `distributionMapFn`.
+/// `distributionMapFn`. Writes of more than `maxNumElementsToExtract`
+/// elements will not be distributed.
+///
 /// Example:
 /// ```
 /// %0 = vector.warp_execute_on_lane_0(%id){
@@ -67,7 +69,7 @@ using DistributionMapFn = std::function<AffineMap(Value)>;
 /// distribute, meaning writes should propagate first.
 void populateDistributeTransferWriteOpPatterns(
     RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
-    PatternBenefit benefit = 2);
+    unsigned maxNumElementsToExtract, PatternBenefit benefit = 2);
 
 /// Move scalar operations with no dependency on the warp op outside of the
 /// region.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 07ecd885752033..1d35900a2c5db7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/Support/FormatVariadic.h"
 #include <numeric>
 #include <utility>
 
@@ -458,7 +459,9 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
 }
 
 /// Distribute transfer_write ops based on the affine map returned by
-/// `distributionMapFn`.
+/// `distributionMapFn`. Writes of more than `maxNumElementsToExtract`
+/// elements will not be distributed.
+///
 /// Example:
 /// ```
 /// %0 = vector.warp_execute_on_lane_0(%id){
@@ -476,9 +479,10 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
 /// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
 struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
   WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
-                      PatternBenefit b = 1)
+                      unsigned maxNumElementsToExtract, PatternBenefit b = 1)
       : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
-        distributionMapFn(std::move(fn)) {}
+        distributionMapFn(std::move(fn)),
+        maxNumElementsToExtract(maxNumElementsToExtract) {}
 
   /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
   /// are multiples of the distribution ratio are supported at the moment.
@@ -553,10 +557,13 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
     Location loc = writeOp.getLoc();
     VectorType vecType = writeOp.getVectorType();
 
-    // Only sink out vector of 1 element for now to not serialize large vector
-    // store. This can later be controlled by user.
-    if (vecType.getNumElements() != 1)
-      return failure();
+    if (vecType.getNumElements() > maxNumElementsToExtract) {
+      return rewriter.notifyMatchFailure(
+          warpOp,
+          llvm::formatv(
+              "writes more elements ({0}) than allowed to extract ({1})",
+              vecType.getNumElements(), maxNumElementsToExtract));
+    }
 
     // Do not process warp ops that contain only TransferWriteOps.
     if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
@@ -616,6 +623,7 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
 
 private:
   DistributionMapFn distributionMapFn;
+  unsigned maxNumElementsToExtract = 1;
 };
 
 /// Sink out elementwise op feeding into a warp op yield.
@@ -1833,9 +1841,9 @@ void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
 
 void mlir::vector::populateDistributeTransferWriteOpPatterns(
     RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
-    PatternBenefit benefit) {
+    unsigned maxNumElementsToExtract, PatternBenefit benefit) {
   patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
-                                    benefit);
+                                    maxNumElementsToExtract, benefit);
 }
 
 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index ab175effa3dfb8..e04fa64f0f8a70 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1,8 +1,20 @@
-// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if | FileCheck %s --check-prefix=CHECK-SCF-IF
-// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform" | FileCheck --check-prefixes=CHECK-HOIST %s
-// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform distribute-transfer-write" | FileCheck --check-prefixes=CHECK-D %s
-// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=propagate-distribution -canonicalize | FileCheck --check-prefixes=CHECK-PROP %s
-// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform distribute-transfer-write propagate-distribution" -canonicalize | FileCheck --check-prefixes=CHECK-DIST-AND-PROP %s
+// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
+// RUN:   --test-vector-warp-distribute=rewrite-warp-ops-to-scf-if | FileCheck %s --check-prefix=CHECK-SCF-IF
+
+// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
+// RUN:   --test-vector-warp-distribute="hoist-uniform" | FileCheck --check-prefixes=CHECK-HOIST %s
+
+// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
+// RUN:   --test-vector-warp-distribute="hoist-uniform distribute-transfer-write max-transfer-write-elements=4" \
+// RUN:   | FileCheck --check-prefixes=CHECK-D %s
+
+// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
+// RUN:  --test-vector-warp-distribute=propagate-distribution --canonicalize \
+// RUN:  | FileCheck --check-prefixes=CHECK-PROP %s
+
+// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
+// RUN:   --test-vector-warp-distribute="hoist-uniform distribute-transfer-write propagate-distribution" \
+// RUN:   --canonicalize | FileCheck --check-prefixes=CHECK-DIST-AND-PROP %s
 
 // CHECK-SCF-IF-DAG: #[[$TIMES2:.*]] = affine_map<()[s0] -> (s0 * 2)>
 // CHECK-SCF-IF-DAG: #[[$TIMES4:.*]] = affine_map<()[s0] -> (s0 * 4)>
@@ -134,6 +146,84 @@ func.func @warp_extract(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : ind
 
 // -----
 
+// Check that we can distribute writes of the maximum allowed number of elements.
+
+// CHECK-D-LABEL: func @warp_extract_4_elems(
+//       CHECK-D:   %[[WARPOP:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4x1xf32>)
+//       CHECK-D:     "test.dummy_op"
+//       CHECK-D:     "test.dummy_op"
+//       CHECK-D:     vector.yield %{{.*}}, %{{.*}} : vector<4xf32>, vector<4x1xf32>
+//       CHECK-D:   }
+//       CHECK-D:   vector.warp_execute_on_lane_0(%{{.*}})[32] {
+//       CHECK-D:     vector.transfer_write %[[WARPOP]]#1, %{{.*}}[%{{.*}}] {{.*}} : vector<4x1xf32>
+//       CHECK-D:   }
+//       CHECK-D:   vector.warp_execute_on_lane_0(%{{.*}})[32] {
+//       CHECK-D:     vector.transfer_write %[[WARPOP]]#0, %{{.*}}[%{{.*}}] {{.*}} : vector<4xf32>
+//       CHECK-D:   }
+
+func.func @warp_extract_4_elems(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : index) {
+  vector.warp_execute_on_lane_0(%laneid)[32] {
+    %c0 = arith.constant 0 : index
+    %v = "test.dummy_op"() : () -> (vector<4xf32>)
+    %v1 = "test.dummy_op"() : () -> (vector<4x1xf32>)
+    vector.transfer_write %v1, %arg1[%c0, %c0] : vector<4x1xf32>, memref<1024x1024xf32>
+    vector.transfer_write %v, %arg1[%c0, %c0] : vector<4xf32>, memref<1024x1024xf32>
+  }
+  return
+}
+
+// -----
+
+// Check that we do not distribute writes larger than the maximum allowed
+// number of elements.
+
+// CHECK-D-LABEL: func @warp_extract_5_elems(
+//       CHECK-D:   arith.constant 0 : index
+//       CHECK-D:   vector.warp_execute_on_lane_0(%{{.*}})[32] {
+//       CHECK-D:     %[[V:.+]] = "test.dummy_op"
+//       CHECK-D:     %[[V1:.+]] = "test.dummy_op"
+//       CHECK-D:     vector.transfer_write %[[V1]], %{{.*}}[%{{.*}}] {{.*}} : vector<5x1xf32>
+//       CHECK-D:     vector.transfer_write %[[V]], %{{.*}}[%{{.*}}] {{.*}} : vector<5xf32>
+//       CHECK-D:   }
+
+func.func @warp_extract_5_elems(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : index) {
+  vector.warp_execute_on_lane_0(%laneid)[32] {
+    %c0 = arith.constant 0 : index
+    %v = "test.dummy_op"() : () -> (vector<5xf32>)
+    %v1 = "test.dummy_op"() : () -> (vector<5x1xf32>)
+    vector.transfer_write %v1, %arg1[%c0, %c0] : vector<5x1xf32>, memref<1024x1024xf32>
+    vector.transfer_write %v, %arg1[%c0, %c0] : vector<5xf32>, memref<1024x1024xf32>
+  }
+  return
+}
+
+// -----
+
+// Check that we do not distribute writes larger than the maximum allowed
+// number of elements, or multiples of the maximum number of elements.
+
+// CHECK-D-LABEL: func @warp_extract_8_elems(
+//       CHECK-D:   arith.constant 0 : index
+//       CHECK-D:   vector.warp_execute_on_lane_0(%{{.*}})[32] {
+//       CHECK-D:     %[[V:.+]] = "test.dummy_op"
+//       CHECK-D:     %[[V1:.+]] = "test.dummy_op"
+//       CHECK-D:     vector.transfer_write %[[V1]], %{{.*}}[%{{.*}}] {{.*}} : vector<8x1xf32>
+//       CHECK-D:     vector.transfer_write %[[V]], %{{.*}}[%{{.*}}] {{.*}} : vector<8xf32>
+//       CHECK-D:   }
+
+func.func @warp_extract_8_elems(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : index) {
+  vector.warp_execute_on_lane_0(%laneid)[32] {
+    %c0 = arith.constant 0 : index
+    %v = "test.dummy_op"() : () -> (vector<8xf32>)
+    %v1 = "test.dummy_op"() : () -> (vector<8x1xf32>)
+    vector.transfer_write %v1, %arg1[%c0, %c0] : vector<8x1xf32>, memref<1024x1024xf32>
+    vector.transfer_write %v, %arg1[%c0, %c0] : vector<8xf32>, memref<1024x1024xf32>
+  }
+  return
+}
+
+// -----
+
 // CHECK-PROP-LABEL:   func @warp_dead_result(
 func.func @warp_dead_result(%laneid: index) -> (vector<1xf32>) {
   // CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>)
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 86b8d5f9b0995a..e593c0defcd29e 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -568,6 +568,11 @@ struct TestVectorDistribution
       llvm::cl::desc("Test distribution of transfer write"),
       llvm::cl::init(false)};
 
+  Option<unsigned> maxTransferWriteElements{
+      *this, "max-transfer-write-elements",
+      llvm::cl::desc("Maximum number of transfer write elements to distribute"),
+      llvm::cl::init(1)};
+
   Option<bool> hoistUniform{*this, "hoist-uniform",
                             llvm::cl::desc("Test hoist uniform"),
                             llvm::cl::init(false)};
@@ -624,7 +629,8 @@ struct TestVectorDistribution
       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
     } else if (distributeTransferWriteOps) {
       RewritePatternSet patterns(ctx);
-      populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
+      populateDistributeTransferWriteOpPatterns(patterns, distributionFn,
+                                                maxTransferWriteElements);
       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
     } else if (propagateDistribution) {
       RewritePatternSet patterns(ctx);

``````````

</details>


https://github.com/llvm/llvm-project/pull/75122


More information about the Mlir-commits mailing list