[Mlir-commits] [mlir] 581672b - [mlir][AVX512] Add while loop-based sparse vector-vector dot product variants.

Matthias Springer llvmlistbot at llvm.org
Mon Mar 15 01:00:04 PDT 2021


Author: Matthias Springer
Date: 2021-03-15T16:59:10+09:00
New Revision: 581672be04d15533caf7ec9830382219f78e4ce9

URL: https://github.com/llvm/llvm-project/commit/581672be04d15533caf7ec9830382219f78e4ce9
DIFF: https://github.com/llvm/llvm-project/commit/581672be04d15533caf7ec9830382219f78e4ce9.diff

LOG: [mlir][AVX512] Add while loop-based sparse vector-vector dot product variants.

Differential Revision: https://reviews.llvm.org/D98480

Added: 
    

Modified: 
    mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-sparse-dot-product.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-sparse-dot-product.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-sparse-dot-product.mlir
index 9fadf0f79782..65c7357714de 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-sparse-dot-product.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-sparse-dot-product.mlir
@@ -9,10 +9,17 @@
 // Each sparse vector is represented by an index memref (A or C) and by a data
 // memref (B or D), containing M or N elements.
 //
-// There are two implementations:
+// There are four different implementations:
 // * `memref_dot_simple`: Simple O(N*M) implementation with two for loops.
 // * `memref_dot_optimized`: An optimized O(N*M) version of the previous
 //   implementation, where the second for loop skips over some elements.
+// * `memref_dot_while`: An optimized O(N+M) implementation that utilizes a
+//   single while loop, coiterating over both vectors.
+// * `memref_dot_while_branchless`: An optimized O(N+M) implementation that
+//   consists of a single while loop and has no branches within the loop.
+//
+// Output of llvm-mca:
+// https://gist.github.com/matthias-springer/72e7ee1b3c467e7aefb6e1fd862e4841
 
 #contraction_accesses = [
  affine_map<(i) -> (i)>,
@@ -224,6 +231,166 @@ func @memref_dot_optimized(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
   return %r0 : f64
 }
 
+// Co-iterating while-loop dot product (A and C are the index memrefs):
+//
+// r = 0.0, a = 0, b = 0
+// while (a < M && b < N) {
+//   segA = A[a:a+8], segC = C[b:b+8]
+//   if   (segC[7] < segA[0]) b += 8
+//   elif (segA[7] < segC[0]) a += 8
+//   else {
+//     r += vector_dot(...)
+//     if   (segA[7] < segC[7]) a += 8
+//     elif (segC[7] < segA[7]) b += 8
+//     else                     a += 8, b += 8
+//   }
+// }
+func @memref_dot_while(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
+                       %m_C : memref<?xi64>, %m_D : memref<?xf64>,
+                       %M : index, %N : index)
+    -> f64 {
+  // Segment stride, lane positions of a segment's first/last index, and the
+  // padding (INT64_MAX) used for index lanes beyond the end of a memref.
+  %start = constant 0 : index
+  %stride = constant 8 : index
+  %first_lane = constant 0 : i32
+  %last_lane = constant 7 : i32
+  %fzero = constant 0.0 : f64
+  %idx_pad = constant 9223372036854775807 : i64
+
+  %sum, %end_a, %end_b = scf.while (%acc = %fzero, %ia = %start, %ib = %start)
+      : (f64, index, index) -> (f64, index, index) {
+    %a_in = cmpi "slt", %ia, %M : index
+    %b_in = cmpi "slt", %ib, %N : index
+    %both_in = and %a_in, %b_in : i1
+    scf.condition(%both_in) %acc, %ia, %ib : f64, index, index
+  } do {
+  ^bb0(%s : f64, %pa : index, %pb : index):
+    // Index segments; these could be loop-carried to avoid redundant reads.
+    %idx_a = vector.transfer_read %m_A[%pa], %idx_pad
+        : memref<?xi64>, vector<8xi64>
+    %idx_b = vector.transfer_read %m_C[%pb], %idx_pad
+        : memref<?xi64>, vector<8xi64>
+    %lo_a = vector.extractelement %idx_a[%first_lane : i32] : vector<8xi64>
+    %hi_a = vector.extractelement %idx_a[%last_lane : i32] : vector<8xi64>
+    %lo_b = vector.extractelement %idx_b[%first_lane : i32] : vector<8xi64>
+    %hi_b = vector.extractelement %idx_b[%last_lane : i32] : vector<8xi64>
+
+    // Segment of C entirely below segment of A: skip it.
+    %b_below = cmpi "slt", %hi_b, %lo_a : i64
+    %s_next, %pa_next, %pb_next = scf.if %b_below -> (f64, index, index) {
+      %pb_skip = addi %pb, %stride : index
+      scf.yield %s, %pa, %pb_skip : f64, index, index
+    } else {
+      // Segment of A entirely below segment of C: skip it.
+      %a_below = cmpi "slt", %hi_a, %lo_b : i64
+      %s_x, %pa_x, %pb_x = scf.if %a_below -> (f64, index, index) {
+        %pa_skip = addi %pa, %stride : index
+        scf.yield %s, %pa_skip, %pb : f64, index, index
+      } else {
+        // Overlapping segments: load the data and accumulate.
+        %val_a = vector.transfer_read %m_B[%pa], %fzero
+            : memref<?xf64>, vector<8xf64>
+        %val_b = vector.transfer_read %m_D[%pb], %fzero
+            : memref<?xf64>, vector<8xf64>
+        %partial = call @vector_dot(%idx_a, %val_a, %idx_b, %val_b)
+            : (vector<8xi64>, vector<8xf64>, vector<8xi64>, vector<8xf64>)
+                -> f64
+        %s_acc = addf %s, %partial : f64
+        // Advance the segment that ends first; on a tie, advance both.
+        %a_first = cmpi "slt", %hi_a, %hi_b : i64
+        %pa_y, %pb_y = scf.if %a_first -> (index, index) {
+          %pa_adv = addi %pa, %stride : index
+          scf.yield %pa_adv, %pb : index, index
+        } else {
+          %b_first = cmpi "slt", %hi_b, %hi_a : i64
+          %pa_z, %pb_z = scf.if %b_first -> (index, index) {
+            %pb_adv = addi %pb, %stride : index
+            scf.yield %pa, %pb_adv : index, index
+          } else {
+            %pa_both = addi %pa, %stride : index
+            %pb_both = addi %pb, %stride : index
+            scf.yield %pa_both, %pb_both : index, index
+          }
+          scf.yield %pa_z, %pb_z : index, index
+        }
+        scf.yield %s_acc, %pa_y, %pb_y : f64, index, index
+      }
+      scf.yield %s_x, %pa_x, %pb_x : f64, index, index
+    }
+    scf.yield %s_next, %pa_next, %pb_next : f64, index, index
+  }
+
+  return %sum : f64
+}
+
+// Branch-free dot product: apart from the while loop itself there is no
+// control flow. A and C are the index memrefs.
+//
+// r = 0.0, a = 0, b = 0
+// while (a < M && b < N) {
+//   segA = A[a:a+8], segC = C[b:b+8]
+//   r += vector_dot(...)
+//   a += (segA[7] <= segC[7]) * 8
+//   b += (segC[7] <= segA[7]) * 8
+// }
+func @memref_dot_while_branchless(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
+                                  %m_C : memref<?xi64>, %m_D : memref<?xf64>,
+                                  %M : index, %N : index)
+    -> f64 {
+  // Segment stride, lane of a segment's last index, and index padding.
+  %start = constant 0 : index
+  %stride = constant 8 : index
+  %last_lane = constant 7 : i32
+  %fzero = constant 0.0 : f64
+  %idx_pad = constant 9223372036854775807 : i64
+
+  %sum, %end_a, %end_b = scf.while (%acc = %fzero, %ia = %start, %ib = %start)
+      : (f64, index, index) -> (f64, index, index) {
+    %a_in = cmpi "slt", %ia, %M : index
+    %b_in = cmpi "slt", %ib, %N : index
+    %both_in = and %a_in, %b_in : i1
+    scf.condition(%both_in) %acc, %ia, %ib : f64, index, index
+  } do {
+  ^bb0(%s : f64, %pa : index, %pb : index):
+    // Load both segments unconditionally (could be loop-carried instead).
+    %idx_a = vector.transfer_read %m_A[%pa], %idx_pad
+        : memref<?xi64>, vector<8xi64>
+    %val_a = vector.transfer_read %m_B[%pa], %fzero
+        : memref<?xf64>, vector<8xf64>
+    %idx_b = vector.transfer_read %m_C[%pb], %idx_pad
+        : memref<?xi64>, vector<8xi64>
+    %val_b = vector.transfer_read %m_D[%pb], %fzero
+        : memref<?xf64>, vector<8xf64>
+
+    %partial = call @vector_dot(%idx_a, %val_a, %idx_b, %val_b)
+        : (vector<8xi64>, vector<8xf64>, vector<8xi64>, vector<8xf64>)
+            -> f64
+    %s_acc = addf %s, %partial : f64
+
+    %hi_a = vector.extractelement %idx_a[%last_lane : i32] : vector<8xi64>
+    %hi_b = vector.extractelement %idx_b[%last_lane : i32] : vector<8xi64>
+
+    // Advance a by 8 iff segA[7] <= segC[7], computed as zext(cond) * 8.
+    %adv_a = cmpi "sle", %hi_a, %hi_b : i64
+    %adv_a_i64 = zexti %adv_a : i1 to i64
+    %adv_a_idx = index_cast %adv_a_i64 : i64 to index
+    %step_a = muli %adv_a_idx, %stride : index
+    %pa_next = addi %pa, %step_a : index
+
+    // Advance b by 8 iff segC[7] <= segA[7]; both advance on a tie.
+    %adv_b = cmpi "sle", %hi_b, %hi_a : i64
+    %adv_b_i64 = zexti %adv_b : i1 to i64
+    %adv_b_idx = index_cast %adv_b_i64 : i64 to index
+    %step_b = muli %adv_b_idx, %stride : index
+    %pb_next = addi %pb, %step_b : index
+
+    scf.yield %s_acc, %pa_next, %pb_next : f64, index, index
+  }
+
+  return %sum : f64
+}
+
 func @entry() -> i32 {
   // Initialize large buffers that can be used for multiple test cases of
   // different sizes.
@@ -256,6 +423,18 @@ func @entry() -> i32 {
   vector.print %r1 : f64
   // CHECK: 86
 
+  %r2 = call @memref_dot_while(%m_A, %m_B, %m_C, %m_D, %M1, %N1)
+      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
+         index, index) -> f64
+  vector.print %r2 : f64
+  // CHECK: 86
+
+  %r6 = call @memref_dot_while_branchless(%m_A, %m_B, %m_C, %m_D, %M1, %N1)
+      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
+         index, index) -> f64
+  vector.print %r6 : f64
+  // CHECK: 86
+
   // --- Test case 2 ---.
   // M and N must be a multiple of 8 if smaller than 128.
   // (Because padding kicks in only for out-of-bounds accesses.)
@@ -275,6 +454,18 @@ func @entry() -> i32 {
   vector.print %r4 : f64
   // CHECK: 111
 
+  %r5 = call @memref_dot_while(%m_A, %m_B, %m_C, %m_D, %M2, %N2)
+      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
+         index, index) -> f64
+  vector.print %r5 : f64
+  // CHECK: 111
+
+  %r7 = call @memref_dot_while_branchless(%m_A, %m_B, %m_C, %m_D, %M2, %N2)
+      : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
+         index, index) -> f64
+  vector.print %r7 : f64
+  // CHECK: 111
+
   // Release all resources.
   dealloc %b_A : memref<128xi64>
   dealloc %b_B : memref<128xf64>


        


More information about the Mlir-commits mailing list