[Mlir-commits] [mlir] 845dc17 - [mlir][Vector] Support broadcast vector type in distribution of vector.warp_execute_on_lane_0.
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Sep 13 08:18:56 PDT 2022
Author: Nicolas Vasilache
Date: 2022-09-13T08:18:47-07:00
New Revision: 845dc178c0bb36af76229f89da7e13e866b010cd
URL: https://github.com/llvm/llvm-project/commit/845dc178c0bb36af76229f89da7e13e866b010cd
DIFF: https://github.com/llvm/llvm-project/commit/845dc178c0bb36af76229f89da7e13e866b010cd.diff
LOG: [mlir][Vector] Support broadcast vector type in distribution of vector.warp_execute_on_lane_0.
This revision improves and tests the broadcast behavior of vector.warp_execute_on_lane_0.
Previously, the implementation of the broadcast behavior of vector.warp_execute_on_lane_0
assumed that the broadcasted value was always of scalar type.
This is not necessarily the case: the yielded value may also be a vector whose type matches the distributed result type.
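For illustration, a minimal sketch of a vector-typed broadcast (not taken verbatim from the tests added here):

```mlir
%r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
  %cst = arith.constant dense<1.0> : vector<1xf32>
  // The result type inside and outside the region is the same, so the
  // yielded vector is broadcast to all 32 lanes instead of distributed.
  vector.yield %cst : vector<1xf32>
}
```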
Differential Revision: https://reviews.llvm.org/D133767
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2c757be09a6a0..3064bdf17bb79 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -67,6 +67,13 @@ struct DistributedLoadStoreHelper {
ArrayRef<Value>{laneId});
}
+ /// Create a store during the process of distributing the
+ /// `vector.warp_execute_on_lane_0` op.
+ /// Vector distribution assumes the following convention regarding the
+ /// temporary buffers that are created to transition values. This **must**
+ /// be properly specified in the `options.warpAllocationFn`:
+ ///   1. Scalars of type T transit through a memref<1xT>.
+ ///   2. Vectors of type V<shapexT> transit through a memref<shapexT>.
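+ ///
+ /// For example (illustrative, mirroring the tests below): an f32 scalar
+ /// transits through a memref<1xf32>, a vector<1xf32> through a
+ /// memref<1xf32>, and a 0-d vector<f32> through a 0-d memref<f32>.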
Operation *buildStore(RewriterBase &b, Location loc, Value val,
Value buffer) {
assert((val == distributedVal || val == sequentialVal) &&
@@ -75,7 +82,12 @@ struct DistributedLoadStoreHelper {
// Vector case must use vector::TransferWriteOp which will later lower to
// vector.store or memref.store depending on further lowerings.
if (val.getType().isa<VectorType>()) {
- SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
+ int64_t rank = sequentialVectorType.getRank();
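+ // Following the convention above, a 0-d vector transits through a 0-d
+ // memref: write it back with empty indices and no in_bounds attribute.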
+ if (rank == 0) {
+ return b.create<vector::TransferWriteOp>(loc, val, buffer, ValueRange{},
+ ArrayRef<bool>{});
+ }
+ SmallVector<Value> indices(rank, zero);
auto maybeDistributedDim = getDistributedVectorDim(distributedVectorType);
assert(maybeDistributedDim && "must be able to deduce distributed dim");
if (val == distributedVal)
@@ -90,17 +102,41 @@ struct DistributedLoadStoreHelper {
return b.create<memref::StoreOp>(loc, val, buffer, zero);
}
+ /// Create a load during the process of distributing the
+ /// `vector.warp_execute_on_lane_0` op.
+ /// Vector distribution assumes the following convention regarding the
+ /// temporary buffers that are created to transition values. This **must**
+ /// be properly specified in the `options.warpAllocationFn`:
+ ///   1. Scalars of type T transit through a memref<1xT>.
+ ///   2. Vectors of type V<shapexT> transit through a memref<shapexT>.
+ ///
+ /// When broadcastMode is true, the load is not distributed to account for
+ /// the broadcast semantics of the `vector.warp_execute_on_lane_0` op.
+ ///
+ /// Example:
+ ///
+ /// ```
+ /// %r = vector.warp_execute_on_lane_0(...) -> (f32) {
+ /// vector.yield %cst : f32
+ /// }
+ /// // Both types are f32. The constant %cst is broadcasted to all lanes.
+ /// ```
+ /// This behavior is described in more detail in the documentation of the op.
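+ ///
+ /// The same broadcast applies when the yielded value is a vector whose
+ /// sequential and distributed types agree, e.g. (illustrative sketch):
+ ///
+ /// ```
+ /// %r = vector.warp_execute_on_lane_0(...) -> (vector<1xf32>) {
+ ///   vector.yield %v : vector<1xf32>
+ /// }
+ /// // Both types are vector<1xf32>; every lane reloads the full vector
+ /// // with an undistributed vector.transfer_read.
+ /// ```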
Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer,
bool broadcastMode = false) {
- // When broadcastMode is true, this is a broadcast.
- // E.g.:
- // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
- // vector.yield %cst : f32
- // }
- // Both types are f32. The constant %cst is broadcasted to all lanes.
- // This is described in more detail in the documentation of the op.
- if (broadcastMode)
+ if (broadcastMode) {
+ // Broadcast mode may occur for either scalar or vector operands.
+ auto vectorType = type.dyn_cast<VectorType>();
+ auto memrefType = buffer.getType().cast<MemRefType>();
+ if (vectorType) {
+ SmallVector<bool> inBounds(memrefType.getRank(), true);
+ return b.create<vector::TransferReadOp>(
+ loc, vectorType, buffer,
+ /*indices=*/SmallVector<Value>(memrefType.getRank(), zero),
+ ArrayRef<bool>(inBounds.begin(), inBounds.end()));
+ }
return b.create<memref::LoadOp>(loc, buffer, zero);
+ }
// Other cases must be vector at the moment.
// Vector case must use vector::TransferReadOp which will later lower to
@@ -328,8 +364,10 @@ struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
helper.buildStore(rewriter, loc, distributedVal, buffer);
// Load sequential vector from buffer, inside the ifOp.
rewriter.setInsertionPointToStart(ifOp.thenBlock());
- bbArgReplacements.push_back(
- helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
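+ // The sequential and distributed types agree only in the broadcast
+ // case; the value is then reloaded unchanged by every lane.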
+ bool broadcastMode =
+ (sequentialVal.getType() == distributedVal.getType());
+ bbArgReplacements.push_back(helper.buildLoad(
+ rewriter, loc, sequentialVal.getType(), buffer, broadcastMode));
}
// Step 3. Insert sync after all the stores and before all the loads.
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 5a70ae8a4994a..20bc623064f9f 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -691,3 +691,46 @@ func.func @dedup(%laneid: index, %v0: vector<4xf32>, %v1: vector<4xf32>)
// CHECK-PROP: return %[[SINGLE_RES]], %[[SINGLE_RES]] : vector<1xf32>, vector<1xf32>
return %r#0, %r#1 : vector<1xf32>, vector<1xf32>
}
+
+// -----
+
+// CHECK-SCF-IF: func @warp_execute_has_broadcast_semantics
+func.func @warp_execute_has_broadcast_semantics(%laneid: index, %s0: f32, %v0: vector<f32>, %v1: vector<1xf32>, %v2: vector<1x1xf32>)
+ -> (f32, vector<f32>, vector<1xf32>, vector<1x1xf32>) {
+ // CHECK-SCF-IF-DAG: %[[C0:.*]] = arith.constant 0 : index
+
+ // CHECK-SCF-IF: scf.if{{.*}}{
+ %r:4 = vector.warp_execute_on_lane_0(%laneid)[32]
+ args(%s0, %v0, %v1, %v2 : f32, vector<f32>, vector<1xf32>, vector<1x1xf32>) -> (f32, vector<f32>, vector<1xf32>, vector<1x1xf32>) {
+ ^bb0(%bs0: f32, %bv0: vector<f32>, %bv1: vector<1xf32>, %bv2: vector<1x1xf32>):
+
+ // CHECK-SCF-IF: vector.transfer_read {{.*}}[%[[C0]], %[[C0]]]{{.*}} {in_bounds = [true, true]} : memref<1x1xf32, 3>, vector<1x1xf32>
+ // CHECK-SCF-IF: vector.transfer_read {{.*}}[%[[C0]]]{{.*}} {in_bounds = [true]} : memref<1xf32, 3>, vector<1xf32>
+ // CHECK-SCF-IF: vector.transfer_read {{.*}}[]{{.*}} : memref<f32, 3>, vector<f32>
+ // CHECK-SCF-IF: memref.load {{.*}}[%[[C0]]] : memref<1xf32, 3>
+ // CHECK-SCF-IF: "some_def_0"(%{{.*}}) : (f32) -> f32
+ // CHECK-SCF-IF: "some_def_1"(%{{.*}}) : (vector<f32>) -> vector<f32>
+ // CHECK-SCF-IF: "some_def_1"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
+ // CHECK-SCF-IF: "some_def_1"(%{{.*}}) : (vector<1x1xf32>) -> vector<1x1xf32>
+ // CHECK-SCF-IF: memref.store {{.*}}[%[[C0]]] : memref<1xf32, 3>
+ // CHECK-SCF-IF: vector.transfer_write {{.*}}[] : vector<f32>, memref<f32, 3>
+ // CHECK-SCF-IF: vector.transfer_write {{.*}}[%[[C0]]] {in_bounds = [true]} : vector<1xf32>, memref<1xf32, 3>
+ // CHECK-SCF-IF: vector.transfer_write {{.*}}[%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x1xf32>, memref<1x1xf32, 3>
+
+ %rs0 = "some_def_0"(%bs0) : (f32) -> f32
+ %rv0 = "some_def_1"(%bv0) : (vector<f32>) -> vector<f32>
+ %rv1 = "some_def_1"(%bv1) : (vector<1xf32>) -> vector<1xf32>
+ %rv2 = "some_def_1"(%bv2) : (vector<1x1xf32>) -> vector<1x1xf32>
+
+ // CHECK-SCF-IF-NOT: vector.yield
+ vector.yield %rs0, %rv0, %rv1, %rv2 : f32, vector<f32>, vector<1xf32>, vector<1x1xf32>
+ }
+
+ // CHECK-SCF-IF: gpu.barrier
+ // CHECK-SCF-IF: %[[RV2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]]{{.*}} {in_bounds = [true, true]} : memref<1x1xf32, 3>, vector<1x1xf32>
+ // CHECK-SCF-IF: %[[RV1:.*]] = vector.transfer_read {{.*}}[%[[C0]]]{{.*}} {in_bounds = [true]} : memref<1xf32, 3>, vector<1xf32>
+ // CHECK-SCF-IF: %[[RV0:.*]] = vector.transfer_read {{.*}}[]{{.*}} : memref<f32, 3>, vector<f32>
+ // CHECK-SCF-IF: %[[RS0:.*]] = memref.load {{.*}}[%[[C0]]] : memref<1xf32, 3>
+ // CHECK-SCF-IF: return %[[RS0]], %[[RV0]], %[[RV1]], %[[RV2]] : f32, vector<f32>, vector<1xf32>, vector<1x1xf32>
+ return %r#0, %r#1, %r#2, %r#3 : f32, vector<f32>, vector<1xf32>, vector<1x1xf32>
+}