[Mlir-commits] [mlir] [mlir][linalg] Preserve cast semantics during generic to matmul (PR #174757)

Prathamesh Tagore llvmlistbot at llvm.org
Thu Jan 15 21:57:15 PST 2026


================
@@ -58,8 +62,118 @@ func.func @op_matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?x
 // CHECK-NOT: linalg.generic
 // CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 
+// Cast-auditing tests: ensure we only specialize when the cast semantics can
+// be expressed by linalg.matmul, and use the cast attribute when needed.
+
+// Check matmul with unsigned cast is correctly raised back to named op.
+func.func @op_matmul_unsigned_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi64>,
+                                   %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi64>) outs(%Out : tensor<16x32xi32>) {
+   ^bb0(%in: i16, %in_0: i64, %out: i32):
+     %1 = arith.extui %in : i16 to i32
+     %2 = arith.trunci %in_0 : i64 to i32
----------------
meshtag wrote:

We only treat casts as conflicting when they have different signedness behaviours, and we do not specialise when they do conflict. This is not such a case (`arith.trunci` is signedness-agnostic), so specialisation is not blocked. Lowering the resulting named op back to linalg.generic also reproduces the original body, so we are not losing any information here.
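The conflict rule in the paragraph above could be modelled roughly as follows. This is a hypothetical Python sketch, not the PR's C++ implementation: the `CAST_SIGNEDNESS` table, the `choose_cast_attr` helper and the classification of each `arith` cast op are assumptions made for illustration. It only assumes that `cast_signed` is the default cast for `linalg.matmul`, so that attribute does not need to be printed.

```python
# Hypothetical sketch of the cast-conflict rule; the real check in the PR
# is written in C++ and may differ in detail.
SIGNED = "signed"
UNSIGNED = "unsigned"
AGNOSTIC = "agnostic"

# Assumed signedness behaviour of common arith cast ops.
CAST_SIGNEDNESS = {
    "arith.extsi": SIGNED,
    "arith.sitofp": SIGNED,
    "arith.fptosi": SIGNED,
    "arith.extui": UNSIGNED,
    "arith.uitofp": UNSIGNED,
    "arith.fptoui": UNSIGNED,
    "arith.trunci": AGNOSTIC,  # truncation drops high bits either way
    "arith.extf": AGNOSTIC,
    "arith.truncf": AGNOSTIC,
}


def choose_cast_attr(cast_ops):
    """Return the cast attribute for linalg.matmul, or None when signed
    and unsigned casts conflict (i.e. keep the linalg.generic)."""
    behaviours = {CAST_SIGNEDNESS[op] for op in cast_ops}
    if SIGNED in behaviours and UNSIGNED in behaviours:
        return None  # conflicting signedness: do not specialise
    if UNSIGNED in behaviours:
        return "cast_unsigned"
    return "cast_signed"  # assumed default, so no attribute is printed


print(choose_cast_attr(["arith.extui", "arith.trunci"]))  # cast_unsigned
print(choose_cast_attr(["arith.extsi", "arith.extui"]))   # None
```

Under these assumptions the test case above (`extui` plus `trunci`) maps to `cast_unsigned`, while a body mixing `extsi` and `extui` is left as a generic.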

For example: 
```
%0 = linalg.generic
         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
         ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi64>) outs(%Out : tensor<16x32xi32>) {
   ^bb0(%in: i16, %in_0: i64, %out: i32):
     %1 = arith.extui %in : i16 to i32
     %2 = arith.trunci %in_0 : i64 to i32
     %3 = arith.muli %1, %2 : i32
     %4 = arith.addi %out, %3 : i32
     linalg.yield %4 : i32
   } -> tensor<16x32xi32>
``` 
with `--linalg-specialize-generic-ops` becomes 
```
%0 = linalg.matmul {cast = #linalg.type_fn<cast_unsigned>} ins(%arg0, %arg1 : tensor<16x8xi16>, tensor<8x32xi64>) outs(%arg2 : tensor<16x32xi32>) -> tensor<16x32xi32>
```
and applying `--linalg-generalize-named-ops` to the above gives
```
%0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<16x8xi16>, tensor<8x32xi64>) outs(%arg2 : tensor<16x32xi32>) {
    ^bb0(%in: i16, %in_0: i64, %out: i32):
      %1 = arith.extui %in : i16 to i32
      %2 = arith.trunci %in_0 : i64 to i32
      %3 = arith.muli %1, %2 : i32
      %4 = arith.addi %out, %3 : i32
      linalg.yield %4 : i32
    } -> tensor<16x32xi32>
```

So no information is lost in the round trip.
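Why the round trip reproduces the same body can be sketched by modelling how generalisation picks an integer-to-integer cast. This is a hypothetical Python model that assumes the rule used by `mlir::convertScalarToDtype` (extend with `extui`/`extsi` depending on signedness, truncate with `trunci`, no op for equal widths); `int_cast_op` is an illustrative name, not an MLIR API.

```python
# Hypothetical model of the integer-to-integer cast emitted when a named op
# is generalised, assuming the convertScalarToDtype rule described above.

def int_cast_op(src_width, dst_width, cast_attr):
    """Return the arith op used to cast iN -> iM under a linalg cast attr."""
    if dst_width > src_width:
        # Widening: the cast attribute decides zero- vs sign-extension.
        return "arith.extui" if cast_attr == "cast_unsigned" else "arith.extsi"
    if dst_width < src_width:
        return "arith.trunci"  # signedness does not matter when truncating
    return None  # same width: no cast needed


# Operands of the matmul in the example: A is i16, B is i64, output is i32.
body = [int_cast_op(16, 32, "cast_unsigned"), int_cast_op(64, 32, "cast_unsigned")]
print(body)  # ['arith.extui', 'arith.trunci']
```

Because a single attribute selects one extension op for every widened operand, a body that mixes `extsi` and `extui` would not be reproduced by any one `cast` value, which is consistent with not specialising such cases.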

https://github.com/llvm/llvm-project/pull/174757

