[Mlir-commits] [mlir] [mlir][linalg] Add mixed precision folding pattern in vectorize_children_and_apply_patterns TD Op (PR #148684)

Adam Siemieniuk llvmlistbot at llvm.org
Tue Aug 5 07:19:03 PDT 2025


================
@@ -1777,3 +1777,158 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// Mixed precision vectorization tests.
+
+// CHECK-LABEL: func @float_mixed_precision_generic_as_contract
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT:     arith.extf
+// CHECK:         vector.contract
+// CHECK:         vector.transfer_write
+func.func @float_mixed_precision_generic_as_contract(%A: memref<8x16xbf16>, %B: memref<16x32xbf16>,
+                         %C: memref<8x32xf32>) {
+  linalg.generic {
+    indexing_maps = [
+    affine_map<(m, n, k) -> (m, k)>,
+    affine_map<(m, n, k) -> (k, n)>,
+    affine_map<(m, n, k) -> (m, n)>
+    ],
+    iterator_types = ["parallel", "parallel", "reduction"]
+  }
+   ins(%A, %B : memref<8x16xbf16>, memref<16x32xbf16>)
+   outs(%C : memref<8x32xf32>) {
+    ^bb(%in: bf16, %in_0: bf16, %c: f32) :
+      %a = arith.extf %in : bf16 to f32
+      %b = arith.extf %in_0 : bf16 to f32
+      %d = arith.mulf %a, %b: f32
+      %e = arith.addf %c, %d: f32
+      linalg.yield %e : f32
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1  { fold_mixed_precision_into_contract, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func @integer_mixed_precision_generic_as_contract
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT:     arith.extsi
+// CHECK:         vector.contract
+// CHECK:         vector.transfer_write
+func.func @integer_mixed_precision_generic_as_contract(%A: memref<8x16xi8>, %B: memref<16x32xi8>,
+                         %C: memref<8x32xi32>) {
+  linalg.generic {
+    indexing_maps = [
+    affine_map<(m, n, k) -> (m, k)>,
+    affine_map<(m, n, k) -> (k, n)>,
+    affine_map<(m, n, k) -> (m, n)>
+    ],
+    iterator_types = ["parallel", "parallel", "reduction"]
+  }
+   ins(%A, %B : memref<8x16xi8>, memref<16x32xi8>)
+   outs(%C : memref<8x32xi32>) {
+    ^bb(%in: i8, %in_0: i8, %c: i32) :
+      %a = arith.extsi %in : i8 to i32
+      %b = arith.extsi %in_0 : i8 to i32
+      %d = arith.muli %a, %b: i32
+      %e = arith.addi %c, %d: i32
+      linalg.yield %e : i32
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1  { fold_mixed_precision_into_contract, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @float_mixed_precision_matmul_as_contract
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT:     arith.extf
+// CHECK:         vector.contract
+// CHECK:         vector.transfer_write
+func.func @float_mixed_precision_matmul_as_contract(%A: tensor<24x12xbf16>,
+                              %B: tensor<12x25xbf16>,
+                              %C: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  %0 = linalg.contract
+      indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+                       affine_map<(m, n, k) -> (k, n)>,
+                       affine_map<(m, n, k) -> (m, n)>]
+      ins(%A, %B : tensor<24x12xbf16>, tensor<12x25xbf16>)
+      outs(%C : tensor<24x25xf32>) -> tensor<24x25xf32>
+  func.return %0 : tensor<24x25xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_mixed_precision_into_contract } : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @integer_mixed_precision_matmul_as_contract
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT:     arith.extf
+// CHECK:         vector.contract
+// CHECK:         vector.transfer_write
+func.func @integer_mixed_precision_matmul_as_contract(%A: tensor<24x12xi8>,
----------------
adam-smnk wrote:

nit: maybe this variant could be dropped altogether - 2 generics with int and float, 1 linalg.contract, and 1 linalg.matmul should be enough for completeness

https://github.com/llvm/llvm-project/pull/148684


More information about the Mlir-commits mailing list