[Mlir-commits] [mlir] [mlir][vector][nfc] Refactor vector.contract matvec tests (PR #72832)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 20 00:13:52 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Andrzej Warzyński (banach-space)
<details>
<summary>Changes</summary>
Update tests in "vector-contract-matvec-transforms.mlir" so that they
are consistent with similar tests in:
* "vector-contract-to-outerproduct-transforms.mlir".
This is to enable further refactoring in a follow-up patch, namely to:
* remove duplication (this will be much easier once consistent naming
is used),
* extend tests in "vector-contract-matvec-transforms.mlir" with cases
for scalable vectors,
* merge "vector-contract-matvec-transforms.mlir" and
"vector-contract-to-outerproduct-transforms.mlir" (there's no need
for 2 different files testing identical transformations).
Overview of changes in this patch:
1. Simplify the test by removing MemRef wrappers - this test verifies
Vector -> Vector transformations and MemRefs are not needed.
2. Use (m, k) indices instead of (i, j).
3. Rename the test functions so that their names reflect the indexing maps (e.g. `@matvec2x2` becomes `@matvec_mk_k_m`).
This is part of a larger effort to improve test coverage for scalable
vectors in the Vector dialect.
---
Full diff: https://github.com/llvm/llvm-project/pull/72832.diff
1 File Affected:
- (modified) mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir (+90-138)
``````````diff
diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
index cfcb14a477b6b71..811fb589792b1a8 100644
--- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
@@ -1,9 +1,9 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
#matvec_accesses = [
- affine_map<(i, j) -> (i, j)>,
- affine_map<(i, j) -> (j)>,
- affine_map<(i, j) -> (i)>
+ affine_map<(m, k) -> (m, k)>,
+ affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (m)>
]
#matvec_trait = {
indexing_maps = #matvec_accesses,
@@ -16,9 +16,9 @@
}
#mattransvec_accesses = [
- affine_map<(i, j) -> (j, i)>,
- affine_map<(i, j) -> (j)>,
- affine_map<(i, j) -> (i)>
+ affine_map<(m, k) -> (k, m)>,
+ affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (m)>
]
#mattransvec_trait = {
indexing_maps = #mattransvec_accesses,
@@ -26,9 +26,9 @@
}
#vecmat_accesses = [
- affine_map<(i, j) -> (j)>,
- affine_map<(i, j) -> (i, j)>,
- affine_map<(i, j) -> (i)>
+ affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (m, k)>,
+ affine_map<(m, k) -> (m)>
]
#vecmat_trait = {
indexing_maps = #vecmat_accesses,
@@ -36,9 +36,9 @@
}
#vecmattrans_accesses = [
- affine_map<(i, j) -> (j)>,
- affine_map<(i, j) -> (j, i)>,
- affine_map<(i, j) -> (i)>
+ affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (k, m)>,
+ affine_map<(m, k) -> (m)>
]
#vecmattrans_trait = {
indexing_maps = #vecmattrans_accesses,
@@ -46,166 +46,118 @@
}
#redpar_vecmattrans_accesses = [
- affine_map<(i, j) -> (i)>,
- affine_map<(i, j) -> (i, j)>,
- affine_map<(i, j) -> (j)>
+ affine_map<(m, k) -> (m)>,
+ affine_map<(m, k) -> (m, k)>,
+ affine_map<(m, k) -> (k)>
]
#redpar_vecmattrans_trait = {
indexing_maps = #redpar_vecmattrans_accesses,
iterator_types = ["reduction", "parallel"]
}
-// CHECK-LABEL: func @matvec2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK-LABEL: func @matvec_mk_k_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T9]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @matvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
- %arg2: memref<vector<2xf32>>) {
- %A = memref.load %arg0[] : memref<vector<2x2xf32>>
- %x = memref.load %arg1[] : memref<vector<2xf32>>
- %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_mk_k_m(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
%0 = vector.contract #matvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
- memref.store %0, %arg2[] : memref<vector<2xf32>>
- return
+ return %0 : vector<2xf32>
}
-// CHECK-LABEL: func @matvecmax2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK-LABEL: func @matvec_mk_k_m_max
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
+// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T9]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @matvecmax2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
- %arg2: memref<vector<2xf32>>) {
- %A = memref.load %arg0[] : memref<vector<2x2xf32>>
- %x = memref.load %arg1[] : memref<vector<2xf32>>
- %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
%0 = vector.contract #matvecmax_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
- memref.store %0, %arg2[] : memref<vector<2xf32>>
- return
+ return %0 : vector<2xf32>
}
-// CHECK-LABEL: func @mattransvec2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK-LABEL: func @matvec_km_k_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T8]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @mattransvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
- %arg2: memref<vector<2xf32>>) {
- %A = memref.load %arg0[] : memref<vector<2x2xf32>>
- %x = memref.load %arg1[] : memref<vector<2xf32>>
- %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_km_k_m(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
%0 = vector.contract #mattransvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
- memref.store %0, %arg2[] : memref<vector<2xf32>>
- return
+ return %0 : vector<2xf32>
}
-// CHECK-LABEL: func @vecmat2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK-LABEL: func @matvec_k_mk_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T9]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @vecmat2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
- %arg2: memref<vector<2xf32>>) {
- %A = memref.load %arg0[] : memref<vector<2x2xf32>>
- %x = memref.load %arg1[] : memref<vector<2xf32>>
- %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_k_mk_m(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
%0 = vector.contract #vecmat_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
- memref.store %0, %arg2[] : memref<vector<2xf32>>
- return
+ return %0 : vector<2xf32>
}
-// CHECK-LABEL: func @vecmattrans2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK-LABEL: func @matvec_k_km_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T8]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
- %arg2: memref<vector<2xf32>>) {
- %A = memref.load %arg0[] : memref<vector<2x2xf32>>
- %x = memref.load %arg1[] : memref<vector<2xf32>>
- %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_k_km_m(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
%0 = vector.contract #vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
- memref.store %0, %arg2[] : memref<vector<2xf32>>
- return
+ return %0 : vector<2xf32>
}
-// CHECK-LABEL: func @redpar_vecmattrans2x2
-// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
-// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
-// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
-// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref<vector<2x2xf32>>
-// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref<vector<2xf32>>
-// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref<vector<2xf32>>
-// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : f32 from vector<2xf32>
+// CHECK-LABEL: func @matvec_m_mk_k
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: memref.store %[[T8]], %[[C]][] : memref<vector<2xf32>>
-// CHECK: return
-func.func @redpar_vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
- %arg2: memref<vector<2xf32>>) {
- %A = memref.load %arg0[] : memref<vector<2x2xf32>>
- %x = memref.load %arg1[] : memref<vector<2xf32>>
- %b = memref.load %arg2[] : memref<vector<2xf32>>
+func.func @matvec_m_mk_k(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
%0 = vector.contract #redpar_vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
- memref.store %0, %arg2[] : memref<vector<2xf32>>
- return
+ return %0 : vector<2xf32>
}
module attributes {transform.with_named_sequence} {
``````````
</details>
https://github.com/llvm/llvm-project/pull/72832
More information about the Mlir-commits
mailing list