[Mlir-commits] [mlir] [mlir][amdgpu][rocdl] Add gfx1250 wmma ops (PR #165064)
Krzysztof Drewniak
llvmlistbot at llvm.org
Fri Oct 24 21:51:20 PDT 2025
================
@@ -161,14 +153,62 @@ func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vec
// -----
func.func @wmma_wrong_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
- // expected-error@+1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}}
+ // expected-error@+1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {4, 16, 32, 64, 128}}}
%0 = amdgpu.wmma 16x16x24 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
func.return %0 : vector<8xi32>
}
// -----
-// Missinng `resetOffset`
+func.func @wmma_source_length_mismatch(%arg0 : vector<8xf16>, %arg1 : vector<16xf16>, %arg2 : vector<8xf32>) -> vector<8xf32> {
+ // expected-error@+1 {{'amdgpu.wmma' op source vectors have different lengths}}
+ %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xf16>, vector<16xf16>, vector<8xf32>
+ func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_mismatched_float_types(%arg0 : vector<8xf16>, %arg1 : vector<8xbf16>, %arg2 : vector<8xf32>) -> vector<8xf32> {
+ // expected-error@+1 {{'amdgpu.wmma' op source element types much match (except for fp8/bf8)}}
+ %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
+ func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_mismatched_int_types(%arg0 : vector<8xi8>, %arg1 : vector<8xi4>, %arg2 : vector<8xi32>) -> vector<8xi32> {
+ // expected-error@+1 {{'amdgpu.wmma' op source element types much match (except for fp8/bf8)}}
----------------
krzysz00 wrote:
I think there's a `much` that should be `must` here. Probably an old typo of mine.
https://github.com/llvm/llvm-project/pull/165064