[Mlir-commits] [mlir] [MLIR][Vector] Refactor tests for contract -> OP transforms (2/N) (PR #73447)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Nov 26 04:37:12 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Andrzej Warzyński (banach-space)
<details>
<summary>Changes</summary>
This is a direct follow-up of #73348. The matvec trait that's used for
`@matvec_m_mk_k` was incorrectly updated from:
```
#redpar_vecmattrans_accesses = [
affine_map<(m, k) -> (m)>,
affine_map<(m, k) -> (m, k)>,
affine_map<(m, k) -> (k)>
]
#redpar_vecmattrans_trait = {
indexing_maps = #redpar_vecmattrans_accesses,
iterator_types = ["reduction", "parallel"]
}
```
to:
```
#matvec_accesses_4 = [
affine_map<(m, k) -> (k)>,
affine_map<(m, k) -> (k, m)>,
affine_map<(m, k) -> (m)>
]
#matvec_trait_4 = {
indexing_maps = #matvec_accesses_4,
iterator_types = ["parallel", "reduction"]
}
```
Note that these traits describe the same matvec operation, hence the
`CHECK` lines are identical for both.
Also, `#redpar_vecmattrans_trait` is identical to `#matvec_trait_8` (up to a
renaming of the iteration dimensions, `(m, k)` vs `(k, m)`) that's already present in:
* "vector-contract-to-outerproduct-matvec-transforms.mlir"
For this reason:
* `@matvec_m_mk_k` is moved next to the other tests that already use `#matvec_trait_8`,
* `#redpar_vecmattrans_trait` is replaced with `#matvec_trait_8`.
This is a part of a larger effort to add cases with scalable vectors to
tests for the Vector dialect. I am refactoring these tests so that it's
easier to identify what cases are tested and where to add tests for
scalable vectors.
Implements #72834.
---
Patch is 39.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/73447.diff
1 Files Affected:
- (modified) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir (+292-215)
``````````diff
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
index 3ca3d344c1abe04..e84a43feaff39dc 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
@@ -1,5 +1,17 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+/// Tests for `vector.contract` -> `vector.outerproduct` transformations for
+/// Matvec operations:
+/// b += A * x.
+/// (b and x are 1-d vectors, A is a 2-d matrix). ATM three different variants
+/// are tested:
+/// * plain (no mask, fixed-width vectors),
+/// * masked (fixed-width vectors),
+/// * scalable (mask + scalable vectors).
+///
+/// TODO: These tests were extracted from 2 different files. If you find the
+/// formatting inconsistent, please update accordingly.
+
#matvec_accesses_1 = [
affine_map<(m, k) -> (m, k)>,
affine_map<(m, k) -> (k)>,
@@ -46,19 +58,67 @@
iterator_types = ["parallel", "reduction"]
}
-#redpar_vecmattrans_accesses = [
- affine_map<(m, k) -> (m)>,
- affine_map<(m, k) -> (m, k)>,
- affine_map<(m, k) -> (k)>
+#matvec_accesses_5 = [
+ affine_map<(k, m) -> (m, k)>,
+ affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (m)>
+]
+#matvec_trait_5 = {
+ indexing_maps = #matvec_accesses_5,
+ iterator_types = ["reduction", "parallel"]
+}
+
+#matvec_accesses_6 = [
+ affine_map<(k, m) -> (k, m)>,
+ affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (m)>
+]
+#matvec_trait_6 = {
+ indexing_maps = #matvec_accesses_6,
+ iterator_types = ["reduction", "parallel"]
+}
+
+#matvec_accesses_7 = [
+ affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (m, k)>,
+ affine_map<(k, m) -> (m)>
+]
+#matvec_trait_7 = {
+ indexing_maps = #matvec_accesses_7,
+ iterator_types = ["reduction", "parallel"]
+}
+
+#matvec_accesses_8 = [
+ affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (k, m)>,
+ affine_map<(k, m) -> (m)>
]
-#redpar_vecmattrans_trait = {
- indexing_maps = #redpar_vecmattrans_accesses,
+#matvec_trait_8 = {
+ indexing_maps = #matvec_accesses_8,
iterator_types = ["reduction", "parallel"]
}
// ============================================================================
// Matvec 1 (plain + masked + scalable)
// ============================================================================
+// CHECK-LABEL: func @matvec_mk_k_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
+// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+func.func @matvec_mk_k_m(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.contract #matvec_trait_1 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
// CHECK-LABEL: func.func @masked_matvec_mk_k_m(
// CHECK-SAME: %{{.*}}: vector<2x3xf32>,
// CHECK-SAME: %{{.*}}: vector<3xf32>,
@@ -73,12 +133,11 @@
// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
-
-func.func @masked_matvec_mk_k_m(%arg0: vector<2x3xf32>,
- %arg1: vector<3xf32>,
- %arg2: vector<2xf32>,
+func.func @masked_matvec_mk_k_m(%A: vector<2x3xf32>,
+ %x: vector<3xf32>,
+ %b: vector<2xf32>,
%m: vector<2x3xi1>) -> vector<2xf32> {
- %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
+ %0 = vector.mask %m { vector.contract #matvec_trait_1 %A, %x, %b
: vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
return %0 : vector<2xf32>
}
@@ -97,46 +156,28 @@ func.func @masked_matvec_mk_k_m(%arg0: vector<2x3xf32>,
// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
-func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[2]x3xf32>,
- %arg1: vector<3xf32>,
- %arg2: vector<[2]xf32>,
+func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%A: vector<[2]x3xf32>,
+ %x: vector<3xf32>,
+ %b: vector<[2]xf32>,
%m: vector<[2]x3xi1>) -> vector<[2]xf32> {
- %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
+ %0 = vector.mask %m { vector.contract #matvec_trait_1 %A, %x, %b
: vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32>
return %0 : vector<[2]xf32>
}
-// CHECK-LABEL: func @matvec_mk_k_m
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
-// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-func.func @matvec_mk_k_m(%A: vector<2x2xf32>,
- %x: vector<2xf32>,
- %b: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.contract #matvec_trait_1 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
- return %0 : vector<2xf32>
-}
-
// ============================================================================
// Matvec 1 - max (plain)
// ============================================================================
// CHECK-LABEL: func @matvec_mk_k_m_max
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
+// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
%x: vector<2xf32>,
@@ -149,38 +190,38 @@ func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
// Matvec 2 (plain + masked + scalable)
// ============================================================================
// CHECK-LABEL: @masked_matvec_km_k_m
-// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
-// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[A:.+]]: vector<2x4xf32>
+// CHECK-SAME: %[[X:.+]]: vector<2xf32>
+// CHECK-SAME: %[[B:.+]]: vector<4xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
-func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>,
- %arg1: vector<2xf32>,
- %arg2: vector<4xf32>,
+func.func @masked_matvec_km_k_m(%A: vector<2x4xf32>,
+ %x: vector<2xf32>,
+ %b: vector<4xf32>,
%mask: vector<4x2xi1>) -> vector<4xf32> {
// CHECK: vector.transpose %[[MASK]]
- // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK-NOT: vector.transpose %[[A]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
%res = vector.mask %mask {
- vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
+ vector.contract #matvec_trait_2 %A, %x, %b
: vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
} : vector<4x2xi1> -> vector<4xf32>
return %res : vector<4xf32>
}
// CHECK-LABEL: @masked_matvec_km_k_m_scalable_parallel_dim
-// CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32>
-// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME: %[[A:.+]]: vector<2x[4]xf32>
+// CHECK-SAME: %[[X:.+]]: vector<2xf32>
+// CHECK-SAME: %[[B:.+]]: vector<[4]xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
-func.func @masked_matvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
- %arg1: vector<2xf32>,
- %arg2: vector<[4]xf32>,
+func.func @masked_matvec_km_k_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
+ %x: vector<2xf32>,
+ %b: vector<[4]xf32>,
%mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
// CHECK: vector.transpose %[[MASK]]
- // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK-NOT: vector.transpose %[[A]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
%res = vector.mask %mask {
- vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
+ vector.contract #matvec_trait_2 %A, %x, %b
: vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
} : vector<[4]x2xi1> -> vector<[4]xf32>
return %res : vector<[4]xf32>
@@ -188,13 +229,13 @@ func.func @masked_matvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
// CHECK-LABEL: func @matvec_km_k_m
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
func.func @matvec_km_k_m(%A: vector<2x2xf32>,
%x: vector<2xf32>,
@@ -207,54 +248,53 @@ func.func @matvec_km_k_m(%A: vector<2x2xf32>,
// Matvec 3 (plain + masked + scalable)
// ============================================================================
// CHECK-LABEL: @masked_matvec_k_mk_m
-// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
-// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[A:.+]]: vector<4x2xf32>
+// CHECK-SAME: %[[X:.+]]: vector<2xf32>
+// CHECK-SAME: %[[B:.+]]: vector<4xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
-func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>,
- %arg1: vector<2xf32>,
- %arg2: vector<4xf32>,
+func.func @masked_matvec_k_mk_m(%A: vector<4x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<4xf32>,
%mask: vector<4x2xi1>) -> vector<4xf32> {
// CHECK: vector.transpose %[[MASK]]
- // CHECK: vector.transpose %[[MAT]]
+ // CHECK: vector.transpose %[[A]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
%res = vector.mask %mask {
- vector.contract #matvec_trait_3 %arg1, %arg0, %arg2
+ vector.contract #matvec_trait_3 %x, %A, %b
: vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
} : vector<4x2xi1> -> vector<4xf32>
return %res : vector<4xf32>
}
// CHECK-LABEL: @masked_matvec_k_mk_m_scalable_parallel_dim
-// CHECK-SAME: %[[MAT:.+]]: vector<[4]x2xf32>
-// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME: %[[A:.+]]: vector<[4]x2xf32>
+// CHECK-SAME: %[[X:.+]]: vector<2xf32>
+// CHECK-SAME: %[[B:.+]]: vector<[4]xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
-func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>,
- %arg1: vector<2xf32>,
- %arg2: vector<[4]xf32>,
+func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%A: vector<[4]x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<[4]xf32>,
%mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
// CHECK: vector.transpose %[[MASK]]
- // CHECK: vector.transpose %[[MAT]]
+ // CHECK: vector.transpose %[[A]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
%res = vector.mask %mask {
- vector.contract #matvec_trait_3 %arg1, %arg0, %arg2
+ vector.contract #matvec_trait_3 %x, %A, %b
: vector<2xf32>, vector<[4]x2xf32>, vector<[4]xf32> into vector<[4]xf32>
} : vector<[4]x2xi1> -> vector<[4]xf32>
return %res : vector<[4]xf32>
}
-
// CHECK-LABEL: func @matvec_k_mk_m
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
func.func @matvec_k_mk_m(%A: vector<2x2xf32>,
%x: vector<2xf32>,
@@ -266,253 +306,290 @@ func.func @matvec_k_mk_m(%A: vector<2x2xf32>,
// ============================================================================
// Matvec 4 (plain + masked + scalable)
// ============================================================================
+// CHECK-LABEL: func @matvec_k_km_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+func.func @matvec_k_km_m(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.contract #matvec_trait_4 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
// CHECK-LABEL: @masked_matvec_k_km_m_scalable_parallel_dim
-// CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32>
-// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME: %[[A:.+]]: vector<2x[4]xf32>
+// CHECK-SAME: %[[X:.+]]: vector<2xf32>
+// CHECK-SAME: %[[B:.+]]: vector<[4]xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
-func.func @masked_matvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
- %arg1: vector<2xf32>,
- %arg2: vector<[4]xf32>,
+func.func @masked_matvec_k_km_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
+ %x: vector<2xf32>,
+ %b: vector<[4]xf32>,
%mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
// CHECK: vector.transpose %[[MASK]]
- // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK-NOT: vector.transpose %[[A]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
%res = vector.mask %mask {
- vector.contract #matvec_trait_4 %arg1, %arg0, %arg2
+ vector.contract #matvec_trait_4 %x, %A, %b
: vector<2xf32>, vector<2x[4]xf32>, vector<[4]xf32> into vector<[4]xf32>
} : vector<[4]x2xi1> -> vector<[4]xf32>
return %res : vector<[4]xf32>
}
// CHECK-LABEL: @masked_matvec_k_km_m
-// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
-// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[A:.+]]: vector<2x4xf32>
+// CHECK-SAME: %[[X:.+]]: vector<2xf32>
+// CHECK-SAME: %[[B:.+]]: vector<4xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
-func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>,
- %arg1: vector<2xf32>,
- %arg2: vector<4xf32>,
+func.func @masked_matvec_k_km_m(%A: vector<2x4xf32>,
+ %x: vector<2xf32>,
+ ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/73447
More information about the Mlir-commits
mailing list