[Mlir-commits] [mlir] [mlir][linalg] Preserve cast semantics during generic to matmul (PR #174757)

Adam Siemieniuk llvmlistbot at llvm.org
Fri Jan 16 07:46:29 PST 2026


================
@@ -58,8 +62,118 @@ func.func @op_matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?x
 // CHECK-NOT: linalg.generic
 // CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 
+// Cast-auditing tests: ensure we only specialize when the cast semantics can
+// be expressed by linalg.matmul, and use the cast attribute when needed.
+
+// Check that a matmul whose LHS is zero-extended (arith.extui) is correctly
+// raised back to the named op with the unsigned cast attribute. The RHS is
+// narrowed with arith.trunci, which does not depend on signedness, so it does
+// not conflict with the unsigned cast on the LHS.
+func.func @op_matmul_unsigned_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi64>,
+                                   %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi64>) outs(%Out : tensor<16x32xi32>) {
+   ^bb0(%in: i16, %in_0: i64, %out: i32):
+     %1 = arith.extui %in : i16 to i32
+     %2 = arith.trunci %in_0 : i64 to i32
+     %3 = arith.muli %1, %2 : i32
+     %4 = arith.addi %out, %3 : i32
+     linalg.yield %4 : i32
+   } -> tensor<16x32xi32>
+   return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: op_matmul_unsigned_cast
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
+
+// Signed casts are the default, so no cast attribute is required: with both
+// inputs sign-extended (arith.extsi), the raised named op must not carry the
+// unsigned cast attribute.
+func.func @op_matmul_signed_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
+                                 %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi16>) outs(%Out : tensor<16x32xi32>) {
+   ^bb0(%in: i16, %in_0: i16, %out: i32):
+     %1 = arith.extsi %in : i16 to i32
+     %2 = arith.extsi %in_0 : i16 to i32
+     %3 = arith.muli %1, %2 : i32
+     %4 = arith.addi %out, %3 : i32
+     linalg.yield %4 : i32
+   } -> tensor<16x32xi32>
+   return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: op_matmul_signed_cast
+// CHECK-NOT: linalg.generic
+// CHECK-NOT: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
+// CHECK: linalg.matmul
+
+// Mixed signed/unsigned inputs cannot be encoded with a single cast attribute:
+// the LHS is zero-extended (arith.extui) while the RHS is sign-extended
+// (arith.extsi), so the generic must be left unspecialized.
+func.func @negative_op_matmul_mixed_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
+                                %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi16>) outs(%Out : tensor<16x32xi32>) {
+   ^bb0(%in: i16, %in_0: i16, %out: i32):
+     %1 = arith.extui %in : i16 to i32
+     %2 = arith.extsi %in_0 : i16 to i32
+     %3 = arith.muli %1, %2 : i32
+     %4 = arith.addi %out, %3 : i32
+     linalg.yield %4 : i32
+   } -> tensor<16x32xi32>
+   return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: negative_op_matmul_mixed_cast
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.matmul
+
+// Output-side casts are not representable by the named matmul ops: here the
+// i64 accumulator is truncated to i32 before the add, and the i32 sum is
+// sign-extended back to i64, so the generic must be left unspecialized.
+func.func @negative_op_matmul_output_cast(%A: tensor<16x8xi32>, %B: tensor<8x32xi32>,
+                                 %Out: tensor<16x32xi64>) -> tensor<16x32xi64> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<16x8xi32>, tensor<8x32xi32>) outs(%Out : tensor<16x32xi64>) {
+   ^bb0(%in: i32, %in_0: i32, %out: i64):
+     %3 = arith.trunci %out : i64 to i32
+     %4 = arith.muli %in, %in_0 : i32
+     %5 = arith.addi %3, %4 : i32
+     %6 = arith.extsi %5 : i32 to i64
+     linalg.yield %6 : i64
+   } -> tensor<16x32xi64>
+   return %0 : tensor<16x32xi64>
+}
+
+// CHECK-LABEL: negative_op_matmul_output_cast
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.matmul
+
+// Bitcasts are not modeled by the cast attribute, but should not block
+// specialization.
----------------
adam-smnk wrote:

I'd think so, but that's unrelated to this PR.
Definitely something to revisit.

https://github.com/llvm/llvm-project/pull/174757


More information about the Mlir-commits mailing list