[Mlir-commits] [mlir] [mlir][vector] Update `CombineContractBroadcastMask` (PR #140050)

Han-Chung Wang llvmlistbot at llvm.org
Mon May 19 11:39:44 PDT 2025


================
@@ -116,6 +200,72 @@ func.func @contract_broadcast_unit_dim_reduction(%arg0 : vector<8x4xi32>, %arg1
     return %result : vector<8x8xi32>
 }
 
+// -----
+
+// Same as above, but with a mask.
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: contract_broadcast_unit_dim_reduction_masked_scalable
+//  CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<[8]x4xi32>, %[[ARG2:.+]]: vector<8x[8]xi32>, %[[MASK:.+]]: vector<1x8x[8]x4xi1>)
+//  CHECK:      %[[MASK_SC:.*]] = vector.shape_cast %[[MASK]] : vector<1x8x[8]x4xi1> to vector<8x[8]x4xi1>
+//  CHECK:      %[[R:.*]] = vector.mask %[[MASK_SC]] {
+//  CHECK-SAME: vector.contract
+//  CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+//  CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+//  CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<[8]x4xi32> into vector<8x[8]xi32>
+func.func @contract_broadcast_unit_dim_reduction_masked_scalable(%arg0 : vector<8x4xi32>, %arg1 : vector<[8]x4xi32>, %arg2 : vector<8x[8]xi32>, %mask: vector<1x8x[8]x4xi1>) -> vector<8x[8]xi32> {
+    %0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<1x8x4xi32>
+    %1 = vector.broadcast %arg1 : vector<[8]x4xi32> to vector<1x[8]x4xi32>
+    %result = vector.mask %mask {
+      vector.contract {
+        indexing_maps = [#map0, #map1, #map2],
+        iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+        kind = #vector.kind<add>
+      } %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x[8]x4xi32> into vector<8x[8]xi32>
+    } : vector<1x8x[8]x4xi1> -> vector<8x[8]xi32>
+    return %result : vector<8x[8]xi32>
+}
+
+// -----
+
+// Same as above, but with a scalable dim.
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: contract_broadcast_unit_dim_reduction_masked
+//  CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<8x4xi32>, %[[ARG2:.+]]: vector<8x8xi32>, %[[MASK:.+]]: vector<1x8x8x4xi1>)
+//  CHECK:      %[[MASK_SC:.*]] = vector.shape_cast %[[MASK]] : vector<1x8x8x4xi1> to vector<8x8x4xi1>
+//  CHECK:      %[[R:.*]] = vector.mask %[[MASK_SC]] {
+//  CHECK-SAME: vector.contract
+//  CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+//  CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+//  CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32>
+func.func @contract_broadcast_unit_dim_reduction_masked(%arg0 : vector<8x4xi32>, %arg1 : vector<8x4xi32>, %arg2 : vector<8x8xi32>, %mask: vector<1x8x8x4xi1>) -> vector<8x8xi32> {
+    %0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<1x8x4xi32>
+    %1 = vector.broadcast %arg1 : vector<8x4xi32> to vector<1x8x4xi32>
+    %result = vector.mask %mask {
+      vector.contract {
+        indexing_maps = [#map0, #map1, #map2],
+        iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+        kind = #vector.kind<add>
+      } %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x8x4xi32> into vector<8x8xi32>
+    } : vector<1x8x8x4xi1> -> vector<8x8xi32>
+    return %result : vector<8x8xi32>
+}
----------------
hanhanW wrote:

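I think you need to swap the two test cases (or their comments)? The first one has a scalable dim, and the second one is the plain masked case, but the "Same as above, but with ..." comments describe them the other way around.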

https://github.com/llvm/llvm-project/pull/140050


More information about the Mlir-commits mailing list