[Mlir-commits] [mlir] [mlir][vector] Add support for multi-dim reduction vector distribution (PR #71193)

Lei Zhang llvmlistbot at llvm.org
Thu Nov 9 22:48:03 PST 2023


================
@@ -496,6 +496,117 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2
 
 // -----
 
+//   CHECK-PROP-LABEL:   func @warp_scf_for_multi_reduce(
+//     CHECK-PROP-NOT:   vector.warp_execute_on_lane_0
+//         CHECK-PROP:   scf.for {{.*}} -> (vector<1x4xf32>) {
+//         CHECK-PROP:     scf.for {{.*}} -> (vector<1x4xf32>) {
+//         CHECK-PROP:       vector.transfer_read {{.*}} : memref<2x32x40x384xf32>, vector<1x4xf32>
+//         CHECK-PROP:     }
+//         CHECK-PROP:   }
+//         CHECK-PROP:   vector.reduction <add>
+// CHECK-PROP-COUNT=8:   gpu.shuffle
+//
+//         CHECK-PROP:   scf.for {{.*}} {
+//         CHECK-PROP:     vector.transfer_read
+//         CHECK-PROP:     scf.for {{.*}} {
+//         CHECK-PROP:       vector.warp_execute_on_lane_0
+//         CHECK-PROP:         vector.transfer_read
+//         CHECK-PROP:         vector.transfer_write
+//         CHECK-PROP:       }
+//         CHECK-PROP:     }
+#map = affine_map<(d0, d1) -> (0, 0)>
+func.func @warp_scf_for_multi_reduce(%arg0: memref<2x32x40x384xf32>, %arg1: memref<2x32x40x384xf16>, %arg2: memref<2x32xf32>, %arg3: memref<2x32x40x384xf16>) {
+  %cst = arith.constant dense<1.536000e+04> : vector<8x128xf32>
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<8x128xf32>
+  %cst_1 = arith.constant 9.99999997E-7 : f32
+  %c128 = arith.constant 128 : index
+  %c8 = arith.constant 8 : index
+  %c0 = arith.constant 0 : index
+  %c40 = arith.constant 40 : index
+  %c384 = arith.constant 384 : index
+  %cst_2 = arith.constant 0.000000e+00 : f16
+  %cst_3 = arith.constant 0.000000e+00 : f32
+  %0 = gpu.thread_id  x
+  %1 = arith.truncf %cst_1 : f32 to f16
+  vector.warp_execute_on_lane_0(%0)[256] {
+    %2 = scf.for %arg4 = %c0 to %c40 step %c8 iter_args(%arg5 = %cst_0) -> (vector<8x128xf32>) {
+      %11 = scf.for %arg6 = %c0 to %c384 step %c128 iter_args(%arg7 = %arg5) -> (vector<8x128xf32>) {
+        %12 = vector.transfer_read %arg0[%c0, %c0, %arg4, %arg6], %cst_3 {in_bounds = [true, true]} : memref<2x32x40x384xf32>, vector<8x128xf32>
+        %13 = arith.addf %12, %arg7 : vector<8x128xf32>
+        scf.yield %13 : vector<8x128xf32>
+      }
+      scf.yield %11 : vector<8x128xf32>
+    }
+    %3 = vector.shape_cast %2 : vector<8x128xf32> to vector<1024xf32>
+    %4 = vector.reduction <add>, %3, %cst_3 : vector<1024xf32> into f32
+    %5 = vector.broadcast %4 : f32 to vector<8x128xf32>
+    %6 = arith.divf %5, %cst : vector<8x128xf32>
+    %7 = arith.truncf %6 : vector<8x128xf32> to vector<8x128xf16>
+    %8 = vector.broadcast %1 : f16 to vector<8x128xf16>
+    %9 = arith.addf %7, %8 : vector<8x128xf16>
+    %10 = math.rsqrt %9 : vector<8x128xf16>
+    scf.for %arg4 = %c0 to %c40 step %c8 {
+      %11 = vector.transfer_read %arg2[%c0, %c0], %cst_3 {in_bounds = [true, true], permutation_map = #map} : memref<2x32xf32>, vector<8x128xf32>
+      %12 = arith.truncf %11 : vector<8x128xf32> to vector<8x128xf16>
+      scf.for %arg5 = %c0 to %c384 step %c128 {
+        %13 = vector.transfer_read %arg1[%c0, %c0, %arg4, %arg5], %cst_2 {in_bounds = [true, true]} : memref<2x32x40x384xf16>, vector<8x128xf16>
+        %14 = arith.subf %13, %12 : vector<8x128xf16>
+        %15 = arith.mulf %14, %10 : vector<8x128xf16>
+        vector.transfer_write %15, %arg3[%c0, %c0, %arg4, %arg5] {in_bounds = [true, true]} : vector<8x128xf16>, memref<2x32x40x384xf16>
+      }
+    }
+  }
+  return
+}
+
+// -----
+
+//   CHECK-PROP-LABEL:   func @warp_multi_reduce_3d(
+//     CHECK-PROP-NOT:   vector.warp_execute_on_lane_0
+//         CHECK-PROP:   vector.transfer_read {{.*}} : memref<128x4x64xf32>, vector<1x2x64xf32>
+//         CHECK-PROP:   vector.shape_cast {{.*}} : vector<1x2x64xf32> to vector<128xf32>
+//         CHECK-PROP:   vector.reduction <add>, {{.*}} : vector<128xf32> into f32
+// CHECK-PROP-COUNT=8:   gpu.shuffle
+func.func @warp_multi_reduce_3d(%arg0 : memref<128x4x64xf32>) -> f32 {
+  %0 = gpu.thread_id x
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %ret = vector.warp_execute_on_lane_0(%0)[256] -> f32 {
+    %read = vector.transfer_read %arg0[%c0, %c0, %c0], %cst { in_bounds = [true, true, true] } : memref<128x4x64xf32>, vector<128x4x64xf32>
+    %cast = vector.shape_cast %read : vector<128x4x64xf32> to vector<32768xf32>
+    %out = vector.reduction <add>, %cast, %cst : vector<32768xf32> into f32
+    vector.yield %out : f32
+  }
+  func.return %ret : f32
+}
+
+// -----
+
+//   CHECK-PROP-LABEL:   func @warp_multi_dim_diff_read_cast(
+//     CHECK-PROP-NOT:   vector.warp_execute_on_lane_0
+//         CHECK-PROP:   vector.transfer_read {{.*}} : memref<2x4x16xf32>, vector<1x2x16xf32>
+//         CHECK-PROP:   vector.transfer_read {{.*}} : memref<128xf32>, vector<32xf32>
+//         CHECK-PROP:   vector.shape_cast {{.*}} : vector<1x2x16xf32> to vector<32xf32>
+//         CHECK-PROP:   arith.addf {{.*}} : vector<32xf32>
+//         CHECK-PROP:   vector.reduction <add>, {{.*}} : vector<32xf32> into f32
+// CHECK-PROP-COUNT=2:   gpu.shuffle
----------------
antiagainst wrote:

`CHECK-PROP-COUNT-2`

https://github.com/llvm/llvm-project/pull/71193
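
For reference, FileCheck only recognizes the repetition form when it is spelled `<PREFIX>-COUNT-<num>:` with dashes, so the suggestion presumably applies to all three count directives in this diff, e.g.:

// CHECK-PROP-COUNT-8:   gpu.shuffle
// CHECK-PROP-COUNT-2:   gpu.shuffle

With the `=` spelling, the line is not parsed as a check directive at all, so the expected number of `gpu.shuffle` ops would go unverified.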

