[Mlir-commits] [mlir] [mlir][vector] Allow vector distribution with multiple written elements (PR #75122)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 11 16:58:40 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Jakub Kuderski (kuhar)
<details>
<summary>Changes</summary>
Add a configuration option to allow vector distribution with multiple elements written by a single lane.
This enables vector multi-reductions that produce multiple results per workgroup.
---
Full diff: https://github.com/llvm/llvm-project/pull/75122.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h (+4-2)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+17-9)
- (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+95-5)
- (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+7-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index f32efd94691a41..c7238fcbbb7023 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -43,7 +43,9 @@ void populateWarpExecuteOnLane0OpToScfForPattern(
using DistributionMapFn = std::function<AffineMap(Value)>;
/// Distribute transfer_write ops based on the affine map returned by
-/// `distributionMapFn`.
+/// `distributionMapFn`. Writes of size more than `maxNumElementsToExtract`
+/// will not be distributed.
+///
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id){
@@ -67,7 +69,7 @@ using DistributionMapFn = std::function<AffineMap(Value)>;
/// distribute, meaning writes should propagate first.
void populateDistributeTransferWriteOpPatterns(
RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
- PatternBenefit benefit = 2);
+ unsigned maxNumElementsToExtract, PatternBenefit benefit = 2);
/// Move scalar operations with no dependency on the warp op outside of the
/// region.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 07ecd885752033..1d35900a2c5db7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -16,6 +16,7 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/FormatVariadic.h"
#include <numeric>
#include <utility>
@@ -458,7 +459,9 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
}
/// Distribute transfer_write ops based on the affine map returned by
-/// `distributionMapFn`.
+/// `distributionMapFn`. Writes of size more than `maxNumElementsToExtract`
+/// will not be distributed.
+///
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id){
@@ -476,9 +479,10 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
/// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
- PatternBenefit b = 1)
+ unsigned maxNumElementsToExtract, PatternBenefit b = 1)
: OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
- distributionMapFn(std::move(fn)) {}
+ distributionMapFn(std::move(fn)),
+ maxNumElementsToExtract(maxNumElementsToExtract) {}
/// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
/// are multiples of the distribution ratio are supported at the moment.
@@ -553,10 +557,13 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
Location loc = writeOp.getLoc();
VectorType vecType = writeOp.getVectorType();
- // Only sink out vector of 1 element for now to not serialize large vector
- // store. This can later be controlled by user.
- if (vecType.getNumElements() != 1)
- return failure();
+ if (vecType.getNumElements() > maxNumElementsToExtract) {
+ return rewriter.notifyMatchFailure(
+ warpOp,
+ llvm::formatv(
+ "writes more elements ({0}) than allowed to extract ({1})",
+ vecType.getNumElements(), maxNumElementsToExtract));
+ }
// Do not process warp ops that contain only TransferWriteOps.
if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
@@ -616,6 +623,7 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
private:
DistributionMapFn distributionMapFn;
+ unsigned maxNumElementsToExtract = 1;
};
/// Sink out elementwise op feeding into a warp op yield.
@@ -1833,9 +1841,9 @@ void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
void mlir::vector::populateDistributeTransferWriteOpPatterns(
RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
- PatternBenefit benefit) {
+ unsigned maxNumElementsToExtract, PatternBenefit benefit) {
patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
- benefit);
+ maxNumElementsToExtract, benefit);
}
void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index ab175effa3dfb8..e04fa64f0f8a70 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1,8 +1,20 @@
-// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if | FileCheck %s --check-prefix=CHECK-SCF-IF
-// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform" | FileCheck --check-prefixes=CHECK-HOIST %s
-// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform distribute-transfer-write" | FileCheck --check-prefixes=CHECK-D %s
-// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=propagate-distribution -canonicalize | FileCheck --check-prefixes=CHECK-PROP %s
-// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform distribute-transfer-write propagate-distribution" -canonicalize | FileCheck --check-prefixes=CHECK-DIST-AND-PROP %s
+// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
+// RUN: --test-vector-warp-distribute=rewrite-warp-ops-to-scf-if | FileCheck %s --check-prefix=CHECK-SCF-IF
+
+// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
+// RUN: --test-vector-warp-distribute="hoist-uniform" | FileCheck --check-prefixes=CHECK-HOIST %s
+
+// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
+// RUN: --test-vector-warp-distribute="hoist-uniform distribute-transfer-write max-transfer-write-elements=4" \
+// RUN: | FileCheck --check-prefixes=CHECK-D %s
+
+// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
+// RUN: --test-vector-warp-distribute=propagate-distribution --canonicalize \
+// RUN: | FileCheck --check-prefixes=CHECK-PROP %s
+
+// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
+// RUN: --test-vector-warp-distribute="hoist-uniform distribute-transfer-write propagate-distribution" \
+// RUN: --canonicalize | FileCheck --check-prefixes=CHECK-DIST-AND-PROP %s
// CHECK-SCF-IF-DAG: #[[$TIMES2:.*]] = affine_map<()[s0] -> (s0 * 2)>
// CHECK-SCF-IF-DAG: #[[$TIMES4:.*]] = affine_map<()[s0] -> (s0 * 4)>
@@ -134,6 +146,84 @@ func.func @warp_extract(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : ind
// -----
+// Check that we can distribute writes of the maximum allowed number of elements.
+
+// CHECK-D-LABEL: func @warp_extract_4_elems(
+// CHECK-D: %[[WARPOP:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4x1xf32>)
+// CHECK-D: "test.dummy_op"
+// CHECK-D: "test.dummy_op"
+// CHECK-D: vector.yield %{{.*}}, %{{.*}} : vector<4xf32>, vector<4x1xf32>
+// CHECK-D: }
+// CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] {
+// CHECK-D: vector.transfer_write %[[WARPOP]]#1, %{{.*}}[%{{.*}}] {{.*}} : vector<4x1xf32>
+// CHECK-D: }
+// CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] {
+// CHECK-D: vector.transfer_write %[[WARPOP]]#0, %{{.*}}[%{{.*}}] {{.*}} : vector<4xf32>
+// CHECK-D: }
+
+func.func @warp_extract_4_elems(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : index) {
+ vector.warp_execute_on_lane_0(%laneid)[32] {
+ %c0 = arith.constant 0 : index
+ %v = "test.dummy_op"() : () -> (vector<4xf32>)
+ %v1 = "test.dummy_op"() : () -> (vector<4x1xf32>)
+ vector.transfer_write %v1, %arg1[%c0, %c0] : vector<4x1xf32>, memref<1024x1024xf32>
+ vector.transfer_write %v, %arg1[%c0, %c0] : vector<4xf32>, memref<1024x1024xf32>
+ }
+ return
+}
+
+// -----
+
+// Check that we do not distribute writes larger than the maximum allowed
+// number of elements.
+
+// CHECK-D-LABEL: func @warp_extract_5_elems(
+// CHECK-D: arith.constant 0 : index
+// CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] {
+// CHECK-D: %[[V:.+]] = "test.dummy_op"
+// CHECK-D: %[[V1:.+]] = "test.dummy_op"
+// CHECK-D: vector.transfer_write %[[V1]], %{{.*}}[%{{.*}}] {{.*}} : vector<5x1xf32>
+// CHECK-D: vector.transfer_write %[[V]], %{{.*}}[%{{.*}}] {{.*}} : vector<5xf32>
+// CHECK-D: }
+
+func.func @warp_extract_5_elems(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : index) {
+ vector.warp_execute_on_lane_0(%laneid)[32] {
+ %c0 = arith.constant 0 : index
+ %v = "test.dummy_op"() : () -> (vector<5xf32>)
+ %v1 = "test.dummy_op"() : () -> (vector<5x1xf32>)
+ vector.transfer_write %v1, %arg1[%c0, %c0] : vector<5x1xf32>, memref<1024x1024xf32>
+ vector.transfer_write %v, %arg1[%c0, %c0] : vector<5xf32>, memref<1024x1024xf32>
+ }
+ return
+}
+
+// -----
+
+// Check that we do not distribute writes larger than the maximum allowed
+// number of elements, even when the size is an exact multiple of that maximum.
+
+// CHECK-D-LABEL: func @warp_extract_8_elems(
+// CHECK-D: arith.constant 0 : index
+// CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] {
+// CHECK-D: %[[V:.+]] = "test.dummy_op"
+// CHECK-D: %[[V1:.+]] = "test.dummy_op"
+// CHECK-D: vector.transfer_write %[[V1]], %{{.*}}[%{{.*}}] {{.*}} : vector<8x1xf32>
+// CHECK-D: vector.transfer_write %[[V]], %{{.*}}[%{{.*}}] {{.*}} : vector<8xf32>
+// CHECK-D: }
+
+func.func @warp_extract_8_elems(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : index) {
+ vector.warp_execute_on_lane_0(%laneid)[32] {
+ %c0 = arith.constant 0 : index
+ %v = "test.dummy_op"() : () -> (vector<8xf32>)
+ %v1 = "test.dummy_op"() : () -> (vector<8x1xf32>)
+ vector.transfer_write %v1, %arg1[%c0, %c0] : vector<8x1xf32>, memref<1024x1024xf32>
+ vector.transfer_write %v, %arg1[%c0, %c0] : vector<8xf32>, memref<1024x1024xf32>
+ }
+ return
+}
+
+// -----
+
// CHECK-PROP-LABEL: func @warp_dead_result(
func.func @warp_dead_result(%laneid: index) -> (vector<1xf32>) {
// CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>)
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 86b8d5f9b0995a..e593c0defcd29e 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -568,6 +568,11 @@ struct TestVectorDistribution
llvm::cl::desc("Test distribution of transfer write"),
llvm::cl::init(false)};
+ Option<unsigned> maxTransferWriteElements{
+ *this, "max-transfer-write-elements",
+ llvm::cl::desc("Maximum number of transfer write elements to distribute"),
+ llvm::cl::init(1)};
+
Option<bool> hoistUniform{*this, "hoist-uniform",
llvm::cl::desc("Test hoist uniform"),
llvm::cl::init(false)};
@@ -624,7 +629,8 @@ struct TestVectorDistribution
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
} else if (distributeTransferWriteOps) {
RewritePatternSet patterns(ctx);
- populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
+ populateDistributeTransferWriteOpPatterns(patterns, distributionFn,
+ maxTransferWriteElements);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
} else if (propagateDistribution) {
RewritePatternSet patterns(ctx);
``````````
</details>
https://github.com/llvm/llvm-project/pull/75122
More information about the Mlir-commits
mailing list