[Mlir-commits] [mlir] [mlir][linalg] Preserve cast semantics during generic to matmul (PR #174757)
Adam Siemieniuk
llvmlistbot at llvm.org
Thu Jan 8 01:56:28 PST 2026
================
@@ -124,3 +124,126 @@ func.func @op_matvec(%A: tensor<?x?xf32>, %B: tensor<?xf32>, %Out: tensor<?xf32>
}
// CHECK-LABEL: op_matvec
// CHECK: linalg.generic
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_unsigned_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi64>,
+ %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi64>) outs(%Out : tensor<16x32xi32>) {
+ ^bb0(%in: i16, %in_0: i64, %out: i32):
+ %1 = arith.extui %in : i16 to i32
+ %2 = arith.trunci %in_0 : i64 to i32
+ %3 = arith.muli %1, %2 : i32
+ %4 = arith.addi %out, %3 : i32
+ linalg.yield %4 : i32
+ } -> tensor<16x32xi32>
+ return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: op_matmul_unsigned_cast
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_signed_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
+ %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi16>) outs(%Out : tensor<16x32xi32>) {
+ ^bb0(%in: i16, %in_0: i16, %out: i32):
+ %1 = arith.extsi %in : i16 to i32
+ %2 = arith.extsi %in_0 : i16 to i32
+ %3 = arith.muli %1, %2 : i32
+ %4 = arith.addi %out, %3 : i32
+ linalg.yield %4 : i32
+ } -> tensor<16x32xi32>
+ return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: op_matmul_signed_cast
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_mixed_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
+ %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi16>) outs(%Out : tensor<16x32xi32>) {
+ ^bb0(%in: i16, %in_0: i16, %out: i32):
+ %1 = arith.extui %in : i16 to i32
+ %2 = arith.extsi %in_0 : i16 to i32
+ %3 = arith.muli %1, %2 : i32
+ %4 = arith.addi %out, %3 : i32
+ linalg.yield %4 : i32
+ } -> tensor<16x32xi32>
+ return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: op_matmul_mixed_cast
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.matmul
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_output_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
----------------
adam-smnk wrote:
nit: use the `@negative_` prefix for test cases that are expected not to match (i.e. the `linalg.generic` must stay as is); see the other tests in this file
https://github.com/llvm/llvm-project/pull/174757
More information about the Mlir-commits
mailing list