[Mlir-commits] [mlir] 581672b - [mlir][AVX512] Add while loop-based sparse vector-vector dot product variants.
Matthias Springer
llvmlistbot at llvm.org
Mon Mar 15 01:00:04 PDT 2021
Author: Matthias Springer
Date: 2021-03-15T16:59:10+09:00
New Revision: 581672be04d15533caf7ec9830382219f78e4ce9
URL: https://github.com/llvm/llvm-project/commit/581672be04d15533caf7ec9830382219f78e4ce9
DIFF: https://github.com/llvm/llvm-project/commit/581672be04d15533caf7ec9830382219f78e4ce9.diff
LOG: [mlir][AVX512] Add while loop-based sparse vector-vector dot product variants.
Differential Revision: https://reviews.llvm.org/D98480
Added:
Modified:
mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-sparse-dot-product.mlir
Removed:
################################################################################
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-sparse-dot-product.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-sparse-dot-product.mlir
index 9fadf0f79782..65c7357714de 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-sparse-dot-product.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-sparse-dot-product.mlir
@@ -9,10 +9,17 @@
// Each sparse vector is represented by an index memref (A or C) and by a data
// memref (B or D), containing M or N elements.
//
-// There are two implementations:
+// There are four different implementations:
// * `memref_dot_simple`: Simple O(N*M) implementation with two for loops.
// * `memref_dot_optimized`: An optimized O(N*M) version of the previous
// implementation, where the second for loop skips over some elements.
+// * `memref_dot_while`: An optimized O(N+M) implementation that utilizes a
+// single while loop, coiterating over both vectors.
+// * `memref_dot_while_branchless`: An optimized O(N+M) implementation that
+// consists of a single while loop and has no branches within the loop.
+//
+// Output of llvm-mca:
+// https://gist.github.com/matthias-springer/72e7ee1b3c467e7aefb6e1fd862e4841
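As background, the O(N+M) variants walk both sorted index arrays in lockstep, advancing whichever side has the smaller current index. A minimal scalar C sketch of that coiteration, using the A/B/C/D naming from the comment above (the function name and signature are illustrative only, not part of this commit):

#include <stddef.h>
#include <stdint.h>

// Scalar model of an O(N+M) sparse dot product.
// A/B: indices/values of the first vector (M entries, indices sorted).
// C/D: indices/values of the second vector (N entries, indices sorted).
double scalar_sparse_dot(const int64_t *A, const double *B, size_t M,
                         const int64_t *C, const double *D, size_t N) {
  double r = 0.0;
  size_t a = 0, b = 0;
  while (a < M && b < N) {
    if (A[a] < C[b]) {
      a++;                  // index present only in the first vector
    } else if (C[b] < A[a]) {
      b++;                  // index present only in the second vector
    } else {
      r += B[a] * D[b];     // matching index: accumulate the product
      a++;
      b++;
    }
  }
  return r;
}

The MLIR implementations added below apply the same idea to segments of 8 elements at a time instead of single elements.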
#contraction_accesses = [
affine_map<(i) -> (i)>,
@@ -224,6 +231,166 @@ func @memref_dot_optimized(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
return %r0 : f64
}
+// Vector dot product with a while loop. Implemented as follows:
+//
+// r = 0.0, a = 0, b = 0
+// while (a < M && b < N) {
+// segA = A[a:a+8], segB = B[b:b+8]
+// if (segB[7] < segA[0]) b += 8
+// elif (segA[7] < segB[0]) a += 8
+// else {
+// r += vector_dot(...)
+// if (segA[7] < segB[7]) a += 8
+// elif (segB[7] < segA[7]) b += 8
+// else a += 8, b += 8
+// }
+// }
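The @vector_dot callee is defined earlier in this test file and does not appear in this diff; it accumulates products over index matches between the two 8-element segments. A scalar C model of that contract, assuming the behavior implied by the pseudocode above (the helper name is hypothetical):

#include <stdint.h>

// Scalar model of vector_dot: sum the products of values whose indices
// match between two 8-element segments.
static double vector_dot_model(const int64_t segA_idx[8], const double segA_val[8],
                               const int64_t segB_idx[8], const double segB_val[8]) {
  double r = 0.0;
  for (int i = 0; i < 8; ++i)
    for (int j = 0; j < 8; ++j)
      if (segA_idx[i] == segB_idx[j])
        r += segA_val[i] * segB_val[j];
  return r;
}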
+func @memref_dot_while(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
+ %m_C : memref<?xi64>, %m_D : memref<?xf64>,
+ %M : index, %N : index)
+ -> f64 {
+ // Helper constants for loops.
+ %c0 = constant 0 : index
+ %i0 = constant 0 : i32
+ %i7 = constant 7 : i32
+ %c8 = constant 8 : index
+
+ %data_zero = constant 0.0 : f64
+ %index_padding = constant 9223372036854775807 : i64
+
+ %r0, %a0, %b0 = scf.while (%r1 = %data_zero, %a1 = %c0, %b1 = %c0)
+ : (f64, index, index) -> (f64, index, index) {
+ %cond_i = cmpi "slt", %a1, %M : index
+ %cond_j = cmpi "slt", %b1, %N : index
+ %cond = and %cond_i, %cond_j : i1
+ scf.condition(%cond) %r1, %a1, %b1 : f64, index, index
+ } do {
+ ^bb0(%r1 : f64, %a1 : index, %b1 : index):
+ // v_A, v_B, seg*_* could be part of the loop state to avoid a few
+ // redundant reads.
+ %v_A = vector.transfer_read %m_A[%a1], %index_padding
+ : memref<?xi64>, vector<8xi64>
+ %v_C = vector.transfer_read %m_C[%b1], %index_padding
+ : memref<?xi64>, vector<8xi64>
+
+ %segA_min = vector.extractelement %v_A[%i0 : i32] : vector<8xi64>
+ %segA_max = vector.extractelement %v_A[%i7 : i32] : vector<8xi64>
+ %segB_min = vector.extractelement %v_C[%i0 : i32] : vector<8xi64>
+ %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64>
+
+ %seg1_done = cmpi "slt", %segB_max, %segA_min : i64
+ %r2, %a2, %b2 = scf.if %seg1_done -> (f64, index, index) {
+ %b3 = addi %b1, %c8 : index
+ scf.yield %r1, %a1, %b3 : f64, index, index
+ } else {
+ %seg0_done = cmpi "slt", %segA_max, %segB_min : i64
+ %r4, %a4, %b4 = scf.if %seg0_done -> (f64, index, index) {
+ %a5 = addi %a1, %c8 : index
+ scf.yield %r1, %a5, %b1 : f64, index, index
+ } else {
+ %v_B = vector.transfer_read %m_B[%a1], %data_zero
+ : memref<?xf64>, vector<8xf64>
+ %v_D = vector.transfer_read %m_D[%b1], %data_zero
+ : memref<?xf64>, vector<8xf64>
+
+ %subresult = call @vector_dot(%v_A, %v_B, %v_C, %v_D)
+ : (vector<8xi64>, vector<8xf64>, vector<8xi64>, vector<8xf64>)
+ -> f64
+ %r6 = addf %r1, %subresult : f64
+
+ %incr_a = cmpi "slt", %segA_max, %segB_max : i64
+ %a6, %b6 = scf.if %incr_a -> (index, index) {
+ %a7 = addi %a1, %c8 : index
+ scf.yield %a7, %b1 : index, index
+ } else {
+ %incr_b = cmpi "slt", %segB_max, %segA_max : i64
+ %a8, %b8 = scf.if %incr_b -> (index, index) {
+ %b9 = addi %b1, %c8 : index
+ scf.yield %a1, %b9 : index, index
+ } else {
+ %a10 = addi %a1, %c8 : index
+ %b10 = addi %b1, %c8 : index
+ scf.yield %a10, %b10 : index, index
+ }
+ scf.yield %a8, %b8 : index, index
+ }
+ scf.yield %r6, %a6, %b6 : f64, index, index
+ }
+ scf.yield %r4, %a4, %b4 : f64, index, index
+ }
+ scf.yield %r2, %a2, %b2 : f64, index, index
+ }
+
+ return %r0 : f64
+}
+
+// Vector dot product with a while loop that has no branches (apart from the
+// while loop itself). Implemented as follows:
+//
+// r = 0.0, a = 0, b = 0
+// while (a < M && b < N) {
+// segA = A[a:a+8], segB = B[b:b+8]
+// r += vector_dot(...)
+// a += (segA[7] <= segB[7]) * 8
+// b += (segB[7] <= segA[7]) * 8
+// }
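Each i1 comparison is widened to 0 or 1 and scaled by 8, so a pointer advances exactly when its segment maximum is less than or equal to the other side's; on equality both conditions hold and both pointers advance, matching the three-way branch of the previous variant. A branch-free C equivalent of the update step (an illustrative helper, not part of the commit):

#include <stddef.h>
#include <stdint.h>

// Branch-free advance of both segment offsets: each boolean comparison
// becomes a 0/8 increment. On equality, both offsets advance.
static void advance_segments(size_t *a, size_t *b,
                             int64_t segA_max, int64_t segB_max) {
  *a += (size_t)(segA_max <= segB_max) * 8;
  *b += (size_t)(segB_max <= segA_max) * 8;
}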
+func @memref_dot_while_branchless(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
+ %m_C : memref<?xi64>, %m_D : memref<?xf64>,
+ %M : index, %N : index)
+ -> f64 {
+ // Helper constants for loops.
+ %c0 = constant 0 : index
+ %i7 = constant 7 : i32
+ %c8 = constant 8 : index
+
+ %data_zero = constant 0.0 : f64
+ %index_padding = constant 9223372036854775807 : i64
+
+ %r0, %a0, %b0 = scf.while (%r1 = %data_zero, %a1 = %c0, %b1 = %c0)
+ : (f64, index, index) -> (f64, index, index) {
+ %cond_i = cmpi "slt", %a1, %M : index
+ %cond_j = cmpi "slt", %b1, %N : index
+ %cond = and %cond_i, %cond_j : i1
+ scf.condition(%cond) %r1, %a1, %b1 : f64, index, index
+ } do {
+ ^bb0(%r1 : f64, %a1 : index, %b1 : index):
+ // v_A, v_B, seg*_* could be part of the loop state to avoid a few
+ // redundant reads.
+ %v_A = vector.transfer_read %m_A[%a1], %index_padding
+ : memref<?xi64>, vector<8xi64>
+ %v_B = vector.transfer_read %m_B[%a1], %data_zero
+ : memref<?xf64>, vector<8xf64>
+ %v_C = vector.transfer_read %m_C[%b1], %index_padding
+ : memref<?xi64>, vector<8xi64>
+ %v_D = vector.transfer_read %m_D[%b1], %data_zero
+ : memref<?xf64>, vector<8xf64>
+
+ %subresult = call @vector_dot(%v_A, %v_B, %v_C, %v_D)
+ : (vector<8xi64>, vector<8xf64>, vector<8xi64>, vector<8xf64>)
+ -> f64
+ %r2 = addf %r1, %subresult : f64
+
+ %segA_max = vector.extractelement %v_A[%i7 : i32] : vector<8xi64>
+ %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64>
+
+ %cond_a = cmpi "sle", %segA_max, %segB_max : i64
+ %cond_a_i64 = zexti %cond_a : i1 to i64
+ %cond_a_idx = index_cast %cond_a_i64 : i64 to index
+ %incr_a = muli %cond_a_idx, %c8 : index
+ %a2 = addi %a1, %incr_a : index
+
+ %cond_b = cmpi "sle", %segB_max, %segA_max : i64
+ %cond_b_i64 = zexti %cond_b : i1 to i64
+ %cond_b_idx = index_cast %cond_b_i64 : i64 to index
+ %incr_b = muli %cond_b_idx, %c8 : index
+ %b2 = addi %b1, %incr_b : index
+
+ scf.yield %r2, %a2, %b2 : f64, index, index
+ }
+
+ return %r0 : f64
+}
+
func @entry() -> i32 {
// Initialize large buffers that can be used for multiple test cases of
// different sizes.
@@ -256,6 +423,18 @@ func @entry() -> i32 {
vector.print %r1 : f64
// CHECK: 86
+ %r2 = call @memref_dot_while(%m_A, %m_B, %m_C, %m_D, %M1, %N1)
+ : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
+ index, index) -> f64
+ vector.print %r2 : f64
+ // CHECK: 86
+
+ %r6 = call @memref_dot_while_branchless(%m_A, %m_B, %m_C, %m_D, %M1, %N1)
+ : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
+ index, index) -> f64
+ vector.print %r6 : f64
+ // CHECK: 86
+
// --- Test case 2 ---.
// M and N must be a multiple of 8 if smaller than 128.
// (Because padding kicks in only for out-of-bounds accesses.)
@@ -275,6 +454,18 @@ func @entry() -> i32 {
vector.print %r4 : f64
// CHECK: 111
+ %r5 = call @memref_dot_while(%m_A, %m_B, %m_C, %m_D, %M2, %N2)
+ : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
+ index, index) -> f64
+ vector.print %r5 : f64
+ // CHECK: 111
+
+ %r7 = call @memref_dot_while_branchless(%m_A, %m_B, %m_C, %m_D, %M2, %N2)
+ : (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
+ index, index) -> f64
+ vector.print %r7 : f64
+ // CHECK: 111
+
// Release all resources.
dealloc %b_A : memref<128xi64>
dealloc %b_B : memref<128xf64>