[Mlir-commits] [mlir] [mlir][linalg] Preserve cast semantics during generic to matmul (PR #174757)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Jan 16 07:23:04 PST 2026
================
@@ -58,8 +62,118 @@ func.func @op_matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?x
// CHECK-NOT: linalg.generic
// CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// Cast-auditing tests: ensure we only specialize when the cast semantics can
+// be expressed by linalg.matmul, and use the cast attribute when needed.
+
+// Check matmul with unsigned cast is correctly raised back to named op.
+func.func @op_matmul_unsigned_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi64>,
+ %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+ %0 = linalg.generic // Matmul-shaped generic; presumably #map/#map1/#map2 are the matmul maps defined earlier in the file.
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi64>) outs(%Out : tensor<16x32xi32>) {
+ ^bb0(%in: i16, %in_0: i64, %out: i32):
+ %1 = arith.extui %in : i16 to i32 // Zero-extend LHS: this is why the CHECK below expects cast_unsigned.
+ %2 = arith.trunci %in_0 : i64 to i32 // Narrow RHS; presumably compatible with extui because truncation does not depend on signedness - confirm against the pattern.
+ %3 = arith.muli %1, %2 : i32
+ %4 = arith.addi %out, %3 : i32
+ linalg.yield %4 : i32
+ } -> tensor<16x32xi32>
+ return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: op_matmul_unsigned_cast
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
+
+// Signed casts are the default, no cast attribute is required.
+func.func @op_matmul_signed_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
+ %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi16>) outs(%Out : tensor<16x32xi32>) {
+ ^bb0(%in: i16, %in_0: i16, %out: i32):
+ %1 = arith.extsi %in : i16 to i32 // Sign-extend LHS.
+ %2 = arith.extsi %in_0 : i16 to i32 // Sign-extend RHS: both inputs use the default (signed) cast, so no cast attribute is expected.
+ %3 = arith.muli %1, %2 : i32
+ %4 = arith.addi %out, %3 : i32
+ linalg.yield %4 : i32
+ } -> tensor<16x32xi32>
+ return %0 : tensor<16x32xi32>
+} // NOTE(review): the cast_unsigned CHECK-NOT below looks ineffective - FileCheck only scans for CHECK-NOT text before the start of the next CHECK match, and that match ("linalg.matmul") begins exactly where the forbidden text would; consider "CHECK: linalg.matmul ins(" instead.
+
+// CHECK-LABEL: op_matmul_signed_cast
+// CHECK-NOT: linalg.generic
+// CHECK-NOT: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
+// CHECK: linalg.matmul
+
+// Mixed signed/unsigned inputs cannot be encoded with a single cast attribute.
+func.func @negative_op_matmul_mixed_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
+ %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi16>) outs(%Out : tensor<16x32xi32>) {
+ ^bb0(%in: i16, %in_0: i16, %out: i32):
+ %1 = arith.extui %in : i16 to i32 // Zero-extend LHS...
+ %2 = arith.extsi %in_0 : i16 to i32 // ...but sign-extend RHS: the inputs need different casts, so the CHECK lines below expect the generic to stay.
+ %3 = arith.muli %1, %2 : i32
+ %4 = arith.addi %out, %3 : i32
+ linalg.yield %4 : i32
+ } -> tensor<16x32xi32>
+ return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: negative_op_matmul_mixed_cast
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.matmul
+
+// Output-side casts are not representable by the named matmul ops.
+func.func @negative_op_matmul_output_cast(%A: tensor<16x8xi32>, %B: tensor<8x32xi32>,
+ %Out: tensor<16x32xi64>) -> tensor<16x32xi64> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<16x8xi32>, tensor<8x32xi32>) outs(%Out : tensor<16x32xi64>) {
+ ^bb0(%in: i32, %in_0: i32, %out: i64):
+ %3 = arith.trunci %out : i64 to i32 // Narrow the accumulator before the add...
+ %4 = arith.muli %in, %in_0 : i32 // (inputs are used without any cast)
+ %5 = arith.addi %3, %4 : i32
+ %6 = arith.extsi %5 : i32 to i64 // ...and widen the sum back: an output-side cast, which the CHECK lines below expect to keep the op generic.
+ linalg.yield %6 : i64
+ } -> tensor<16x32xi64>
+ return %0 : tensor<16x32xi64>
+}
+
+// CHECK-LABEL: negative_op_matmul_output_cast
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.matmul
+
+// Bitcasts are not modeled by the cast attribute, but should not block
+// specialization.
----------------
banach-space wrote:
That looks like a bug, no?
https://github.com/llvm/llvm-project/pull/174757
More information about the Mlir-commits
mailing list