[Mlir-commits] [mlir] 845dc17 - [mlir][Vector] Support broadcast vector type in distribution of vector.warp_execute_on_lane_0.

Nicolas Vasilache llvmlistbot at llvm.org
Tue Sep 13 08:18:56 PDT 2022


Author: Nicolas Vasilache
Date: 2022-09-13T08:18:47-07:00
New Revision: 845dc178c0bb36af76229f89da7e13e866b010cd

URL: https://github.com/llvm/llvm-project/commit/845dc178c0bb36af76229f89da7e13e866b010cd
DIFF: https://github.com/llvm/llvm-project/commit/845dc178c0bb36af76229f89da7e13e866b010cd.diff

LOG: [mlir][Vector] Support broadcast vector type in distribution of vector.warp_execute_on_lane_0.

This revision significantly improves and tests the broadcast behavior of vector.warp_execute_on_lane_0.

Previously, the implementation of the broadcast behavior of vector.warp_execute_on_lane_0
assumed that the broadcasted value was always of scalar type.

This is not necessarily the case.

Differential Revision: https://reviews.llvm.org/D133767

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2c757be09a6a0..3064bdf17bb79 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -67,6 +67,13 @@ struct DistributedLoadStoreHelper {
                                          ArrayRef<Value>{laneId});
   }
 
+  /// Create a store during the process of distributing the
+  /// `vector.warp_execute_on_lane_0` op.
+  /// Vector distribution assumes the following convention regarding the
+  /// temporary buffers that are created to transition values. This **must**
+  /// be properly specified in the `options.warpAllocationFn`:
+  ///   1. scalars of type T transit through a memref<1xT>.
+  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>
   Operation *buildStore(RewriterBase &b, Location loc, Value val,
                         Value buffer) {
     assert((val == distributedVal || val == sequentialVal) &&
@@ -75,7 +82,12 @@ struct DistributedLoadStoreHelper {
     // Vector case must use vector::TransferWriteOp which will later lower to
     //   vector.store of memref.store depending on further lowerings.
     if (val.getType().isa<VectorType>()) {
-      SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
+      int64_t rank = sequentialVectorType.getRank();
+      if (rank == 0) {
+        return b.create<vector::TransferWriteOp>(loc, val, buffer, ValueRange{},
+                                                 ArrayRef<bool>{});
+      }
+      SmallVector<Value> indices(rank, zero);
       auto maybeDistributedDim = getDistributedVectorDim(distributedVectorType);
       assert(maybeDistributedDim && "must be able to deduce distributed dim");
       if (val == distributedVal)
@@ -90,17 +102,41 @@ struct DistributedLoadStoreHelper {
     return b.create<memref::StoreOp>(loc, val, buffer, zero);
   }
 
+  /// Create a load during the process of distributing the
+  /// `vector.warp_execute_on_lane_0` op.
+  /// Vector distribution assumes the following convention regarding the
+  /// temporary buffers that are created to transition values. This **must**
+  /// be properly specified in the `options.warpAllocationFn`:
+  ///   1. scalars of type T transit through a memref<1xT>.
+  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>
+  ///
+  /// When broadcastMode is true, the load is not distributed to account for
+  /// the broadcast semantics of the `vector.warp_execute_on_lane_0` op.
+  ///
+  /// Example:
+  ///
+  /// ```
+  ///   %r = vector.warp_execute_on_lane_0(...) -> (f32) {
+  ///     vector.yield %cst : f32
+  ///   }
+  ///   // Both types are f32. The constant %cst is broadcasted to all lanes.
+  /// ```
+  /// This behavior is described in more detail in the documentation of the op.
   Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer,
                   bool broadcastMode = false) {
-    // When broadcastMode is true, this is a broadcast.
-    // E.g.:
-    // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
-    //   vector.yield %cst : f32
-    // }
-    // Both types are f32. The constant %cst is broadcasted to all lanes.
-    // This is described in more detail in the documentation of the op.
-    if (broadcastMode)
+    if (broadcastMode) {
+      // Broadcast mode may occur for either scalar or vector operands.
+      auto vectorType = type.dyn_cast<VectorType>();
+      auto shape = buffer.getType().cast<MemRefType>();
+      if (vectorType) {
+        SmallVector<bool> inBounds(shape.getRank(), true);
+        return b.create<vector::TransferReadOp>(
+            loc, vectorType, buffer,
+            /*indices=*/SmallVector<Value>(shape.getRank(), zero),
+            ArrayRef<bool>(inBounds.begin(), inBounds.end()));
+      }
       return b.create<memref::LoadOp>(loc, buffer, zero);
+    }
 
     // Other cases must be vector atm.
     // Vector case must use vector::TransferReadOp which will later lower to
@@ -328,8 +364,10 @@ struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
       helper.buildStore(rewriter, loc, distributedVal, buffer);
       // Load sequential vector from buffer, inside the ifOp.
       rewriter.setInsertionPointToStart(ifOp.thenBlock());
-      bbArgReplacements.push_back(
-          helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
+      bool broadcastMode =
+          (sequentialVal.getType() == distributedVal.getType());
+      bbArgReplacements.push_back(helper.buildLoad(
+          rewriter, loc, sequentialVal.getType(), buffer, broadcastMode));
     }
 
     // Step 3. Insert sync after all the stores and before all the loads.

diff  --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 5a70ae8a4994a..20bc623064f9f 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -691,3 +691,46 @@ func.func @dedup(%laneid: index, %v0: vector<4xf32>, %v1: vector<4xf32>)
   // CHECK-PROP: return %[[SINGLE_RES]], %[[SINGLE_RES]] : vector<1xf32>, vector<1xf32>
   return %r#0, %r#1 : vector<1xf32>, vector<1xf32>
 }
+
+// -----
+
+// CHECK-SCF-IF:   func @warp_execute_has_broadcast_semantics
+func.func @warp_execute_has_broadcast_semantics(%laneid: index, %s0: f32, %v0: vector<f32>, %v1: vector<1xf32>, %v2: vector<1x1xf32>) 
+    -> (f32, vector<f32>, vector<1xf32>, vector<1x1xf32>) {
+  // CHECK-SCF-IF-DAG: %[[C0:.*]] = arith.constant 0 : index
+
+  // CHECK-SCF-IF: scf.if{{.*}}{
+  %r:4 = vector.warp_execute_on_lane_0(%laneid)[32]
+      args(%s0, %v0, %v1, %v2 : f32, vector<f32>, vector<1xf32>, vector<1x1xf32>) -> (f32, vector<f32>, vector<1xf32>, vector<1x1xf32>) {
+    ^bb0(%bs0: f32, %bv0: vector<f32>, %bv1: vector<1xf32>, %bv2: vector<1x1xf32>):
+
+      // CHECK-SCF-IF: vector.transfer_read {{.*}}[%[[C0]], %[[C0]]]{{.*}} {in_bounds = [true, true]} : memref<1x1xf32, 3>, vector<1x1xf32>
+      // CHECK-SCF-IF: vector.transfer_read {{.*}}[%[[C0]]]{{.*}} {in_bounds = [true]} : memref<1xf32, 3>, vector<1xf32>
+      // CHECK-SCF-IF: vector.transfer_read {{.*}}[]{{.*}} : memref<f32, 3>, vector<f32>
+      // CHECK-SCF-IF: memref.load {{.*}}[%[[C0]]] : memref<1xf32, 3>
+      // CHECK-SCF-IF: "some_def_0"(%{{.*}}) : (f32) -> f32
+      // CHECK-SCF-IF: "some_def_1"(%{{.*}}) : (vector<f32>) -> vector<f32>
+      // CHECK-SCF-IF: "some_def_1"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
+      // CHECK-SCF-IF: "some_def_1"(%{{.*}}) : (vector<1x1xf32>) -> vector<1x1xf32>
+      // CHECK-SCF-IF: memref.store {{.*}}[%[[C0]]] : memref<1xf32, 3>
+      // CHECK-SCF-IF: vector.transfer_write {{.*}}[] : vector<f32>, memref<f32, 3>
+      // CHECK-SCF-IF: vector.transfer_write {{.*}}[%[[C0]]] {in_bounds = [true]} : vector<1xf32>, memref<1xf32, 3>
+      // CHECK-SCF-IF: vector.transfer_write {{.*}}[%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x1xf32>, memref<1x1xf32, 3>
+
+      %rs0 = "some_def_0"(%bs0) : (f32) -> f32
+      %rv0 = "some_def_1"(%bv0) : (vector<f32>) -> vector<f32>
+      %rv1 = "some_def_1"(%bv1) : (vector<1xf32>) -> vector<1xf32>
+      %rv2 = "some_def_1"(%bv2) : (vector<1x1xf32>) -> vector<1x1xf32>
+
+      // CHECK-SCF-IF-NOT: vector.yield
+      vector.yield %rs0, %rv0, %rv1, %rv2 : f32, vector<f32>, vector<1xf32>, vector<1x1xf32>
+  }
+
+  // CHECK-SCF-IF: gpu.barrier
+  // CHECK-SCF-IF: %[[RV2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]]{{.*}} {in_bounds = [true, true]} : memref<1x1xf32, 3>, vector<1x1xf32>
+  // CHECK-SCF-IF: %[[RV1:.*]] = vector.transfer_read {{.*}}[%[[C0]]]{{.*}} {in_bounds = [true]} : memref<1xf32, 3>, vector<1xf32>
+  // CHECK-SCF-IF: %[[RV0:.*]] = vector.transfer_read {{.*}}[]{{.*}} : memref<f32, 3>, vector<f32>
+  // CHECK-SCF-IF: %[[RS0:.*]] = memref.load {{.*}}[%[[C0]]] : memref<1xf32, 3>
+  // CHECK-SCF-IF: return %[[RS0]], %[[RV0]], %[[RV1]], %[[RV2]] : f32, vector<f32>, vector<1xf32>, vector<1x1xf32>
+  return %r#0, %r#1, %r#2, %r#3 : f32, vector<f32>, vector<1xf32>, vector<1x1xf32>
+}


        


More information about the Mlir-commits mailing list