[Mlir-commits] [mlir] [mlir][nfc] Add tests for linalg.mmt4d (PR #81422)
Cullen Rhodes
llvmlistbot at llvm.org
Tue Feb 13 03:21:16 PST 2024
================
@@ -744,3 +744,29 @@ func.func @illegal_softmax_output_shape(%arg0: tensor<2x16x32xf32>) -> tensor<2x
-> tensor<2x16xf32>
return %1 : tensor<2x16xf32>
}
+
+// -----
+
+func.func @mmt4d_dims_mismatch(%A: tensor<16x16x8x1xf32>,
+ %B: tensor<16x16x8x1xf32>,
+ %C_in: tensor<16x16x8x1xf32>) -> tensor<16x16x8x1xf32> {
+ // expected-error @+1 {{inferred input/output operand #2 has shape's dimension #3 to be 8, but found 1}}
+ %res = linalg.mmt4d
+ ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
+ outs(%C_in: tensor<16x16x8x1xf32>)
+ -> tensor<16x16x8x1xf32>
+ return %res : tensor<16x16x8x1xf32>
+}
+
+// -----
+
+func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>,
+ %B: tensor<16x16x8x1xf32>,
+ %C_in: tensor<8x8xf32>) -> tensor<8x8xf32> {
----------------
c-rhodes wrote:
nit: alignment
```suggestion
func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>,
                               %B: tensor<16x16x8x1xf32>,
                               %C_in: tensor<8x8xf32>) -> tensor<8x8xf32> {
```
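For contrast, here is a minimal sketch of a well-formed counterpart to `@mmt4d_dims_mismatch`. The function name `@mmt4d_valid` is hypothetical and not part of this PR. `linalg.mmt4d` interprets the LHS as M1xK1xM0xK0, the RHS as N1xK1xN0xK0 and the accumulator as M1xN1xM0xN0. With both inputs typed `tensor<16x16x8x1xf32>`, the inferred output shape is therefore 16x16x8x8: its trailing dimension (N0) comes from dimension #2 of `%B`, which is why the first test expects "to be 8, but found 1" for a `tensor<16x16x8x1xf32>` output.

```mlir
// Hypothetical valid variant: only the output's trailing dim changes (1 -> 8).
func.func @mmt4d_valid(%A: tensor<16x16x8x1xf32>,
                       %B: tensor<16x16x8x1xf32>,
                       %C_in: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
  // Output is M1xN1xM0xN0 = 16x16x8x8; N0 = 8 is taken from %B's dim #2.
  %res = linalg.mmt4d
           ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
           outs(%C_in: tensor<16x16x8x8xf32>)
           -> tensor<16x16x8x8xf32>
  return %res : tensor<16x16x8x8xf32>
}
```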
https://github.com/llvm/llvm-project/pull/81422