[Mlir-commits] [mlir] [mlir][nfc] Add tests for linalg.mmt4d (PR #81422)

Cullen Rhodes llvmlistbot at llvm.org
Tue Feb 13 03:21:16 PST 2024


================
@@ -744,3 +744,29 @@ func.func @illegal_softmax_output_shape(%arg0: tensor<2x16x32xf32>) -> tensor<2x
     -> tensor<2x16xf32>
   return %1 : tensor<2x16xf32>
 }
+
+// -----
+
+func.func @mmt4d_dims_mismatch(%A: tensor<16x16x8x1xf32>,
+                               %B: tensor<16x16x8x1xf32>,
+                               %C_in: tensor<16x16x8x1xf32>) -> tensor<16x16x8x1xf32> {
+    // expected-error @+1 {{inferred input/output operand #2 has shape's dimension #3 to be 8, but found 1}}
+    %res = linalg.mmt4d
+                     ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
+                     outs(%C_in: tensor<16x16x8x1xf32>)
+                     -> tensor<16x16x8x1xf32>
+    return %res : tensor<16x16x8x1xf32>
+}
+
+// -----
+
+func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>,
+                 %B: tensor<16x16x8x1xf32>,
+                 %C_in: tensor<8x8xf32>) -> tensor<8x8xf32> {
----------------
c-rhodes wrote:

nit: alignment
```suggestion
func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>,
                               %B: tensor<16x16x8x1xf32>,
                               %C_in: tensor<8x8xf32>) -> tensor<8x8xf32> {
```

https://github.com/llvm/llvm-project/pull/81422
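For reference, here is a sketch of a well-formed counterpart to `@mmt4d_dims_mismatch`. It assumes the usual `linalg.mmt4d` operand layout: LHS `M x K x M0 x K0`, RHS `N x K x N0 x K0`, and accumulator `M x N x M0 x N0`. Under that layout, two `16x16x8x1` inputs give `N0 = 8`, so the accumulator's dimension #3 must be 8. That matches the expected error, which reports `1` for the `16x16x8x1` output in the negative test. The function name `@mmt4d_ok` is hypothetical and is not part of the PR.

```mlir
// Hypothetical well-formed example (not from the PR).
// LHS: M x K x M0 x K0 = 16x16x8x1
// RHS: N x K x N0 x K0 = 16x16x8x1
// ACC: M x N x M0 x N0 = 16x16x8x8  (dimension #3 is N0 = 8)
func.func @mmt4d_ok(%A: tensor<16x16x8x1xf32>,
                    %B: tensor<16x16x8x1xf32>,
                    %C_in: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
  %res = linalg.mmt4d
           ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
           outs(%C_in: tensor<16x16x8x8xf32>)
           -> tensor<16x16x8x8xf32>
  return %res : tensor<16x16x8x8xf32>
}
```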

