[Mlir-commits] [mlir] [mlir][x86vector] Lower BF16 vector.contract to FMA using AVX2 BF16 packed ops. (PR #170267)

Adam Siemieniuk llvmlistbot at llvm.org
Wed Dec 17 02:31:55 PST 2025


================
@@ -385,6 +385,183 @@ func.func @negative_no_memref_src(
 // CHECK: vector.contract
 // CHECK-NOT: vector.fma
 
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1x2xbf16>
+!vecB = vector<1x8x2xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<4x1x2xbf16>
+!memrefB = memref<1x32x2xbf16>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+func.func @negative_vnni_offset_1(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !vecC) -> !vecC
+{
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = ub.poison : bf16
+  %1 = vector.load %arg0[%c0, %c0, %c0] :
+        !memrefA, !vecA
+  %2 = vector.load %arg1[%c0, %c0, %c1] :
+        !memrefB, !vecB
+  %3 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %arg2
+    : !vecA, !vecB into !vecC
+  return %3 : !vecC
+}
+
+// CHECK-LABEL: @negative_vnni_offset_1
+// CHECK: vector.contract
+// CHECK-NOT: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1x2xbf16>
+!vecB = vector<1x8x2xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<4x1x2xbf16>
+!memrefB = memref<1x32x2xbf16>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+#perm0 = affine_map<(d1, d2, d3) -> (d2, d1, d3)>
+func.func @negative_perm_map_not_identical(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !vecC) -> !vecC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : bf16
+  %1 = vector.transfer_read %arg0[%c0, %c0, %c0], %0 {permutation_map = #perm0,
+        in_bounds = [true, true, true]} : !memrefA, !vecA
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+        !memrefB, !vecB
+  %3 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %arg2
+    : !vecA, !vecB into !vecC
+  return %3 : !vecC
+}
+
+// CHECK-LABEL: @negative_perm_map_not_identical
+// CHECK: vector.contract
+// CHECK-NOT: vector.fma
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1x2xbf16>
+!vecB = vector<1x8x2xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<4x1x2xbf16>
+!memrefB = memref<1x32x2xbf16>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+func.func @negative_non_unit_stride(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !vecC) -> !vecC
+{
+  %c0 = arith.constant 0 : index
+  %0 = ub.poison : bf16
+  %subview_1 = memref.subview %arg1[%c0, %c0, %c0] [1, 16, 2] [1, 1, 2] :
+               !memrefB to memref<1x16x2xbf16, strided<[64, 2, 2], offset: ?>>
+
+  %1 = vector.transfer_read %arg0[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %subview_1[%c0, %c0, %c0], %0 {in_bounds = [true, true, true]} :
+        memref<1x16x2xbf16, strided<[64, 2, 2], offset: ?>>, !vecB
+  %3 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %arg2
+    : !vecA, !vecB into !vecC
+  return %3 : !vecC
+}
+
+// CHECK-LABEL: @negative_non_unit_stride
+// CHECK: vector.contract
+// CHECK-NOT: vector.fma
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1x2xbf16>
+!vecB = vector<1x8x2xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<4x1x2xbf16>
+!memrefB = memref<1x32x2xbf16>
+#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)>
+#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)>
+#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)>
+func.func @negative_false_bound(
----------------
adam-smnk wrote:

nit: maybe rename this to `negative_out_of_bounds`?

https://github.com/llvm/llvm-project/pull/170267


More information about the Mlir-commits mailing list