[Mlir-commits] [mlir] [mlir][linalg] Preserve cast semantics during generic to matmul (PR #174757)

Adam Siemieniuk llvmlistbot at llvm.org
Thu Jan 8 01:56:28 PST 2026


================
@@ -124,3 +124,126 @@ func.func @op_matvec(%A: tensor<?x?xf32>, %B: tensor<?xf32>, %Out: tensor<?xf32>
 }
 // CHECK-LABEL: op_matvec
 // CHECK: linalg.generic
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_unsigned_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi64>,
+                                   %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi64>) outs(%Out : tensor<16x32xi32>) {
+   ^bb0(%in: i16, %in_0: i64, %out: i32):
+     %1 = arith.extui %in : i16 to i32
+     %2 = arith.trunci %in_0 : i64 to i32
+     %3 = arith.muli %1, %2 : i32
+     %4 = arith.addi %out, %3 : i32
+     linalg.yield %4 : i32
+   } -> tensor<16x32xi32>
+   return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: op_matmul_unsigned_cast
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_signed_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
+                                 %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi16>) outs(%Out : tensor<16x32xi32>) {
+   ^bb0(%in: i16, %in_0: i16, %out: i32):
+     %1 = arith.extsi %in : i16 to i32
+     %2 = arith.extsi %in_0 : i16 to i32
+     %3 = arith.muli %1, %2 : i32
+     %4 = arith.addi %out, %3 : i32
+     linalg.yield %4 : i32
+   } -> tensor<16x32xi32>
+   return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: op_matmul_signed_cast
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_mixed_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
+                                %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi16>) outs(%Out : tensor<16x32xi32>) {
+   ^bb0(%in: i16, %in_0: i16, %out: i32):
+     %1 = arith.extui %in : i16 to i32
+     %2 = arith.extsi %in_0 : i16 to i32
+     %3 = arith.muli %1, %2 : i32
+     %4 = arith.addi %out, %3 : i32
+     linalg.yield %4 : i32
+   } -> tensor<16x32xi32>
+   return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: op_matmul_mixed_cast
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.matmul
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_output_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
----------------
adam-smnk wrote:

nit: use the `@negative_` prefix for test cases that are not expected to match; see the other tests in this file.

https://github.com/llvm/llvm-project/pull/174757


More information about the Mlir-commits mailing list