[Mlir-commits] [mlir] [mlir][vector] Update `CombineContractBroadcastMask` (PR #140050)
Han-Chung Wang
llvmlistbot at llvm.org
Mon May 19 11:39:44 PDT 2025
================
@@ -116,6 +200,72 @@ func.func @contract_broadcast_unit_dim_reduction(%arg0 : vector<8x4xi32>, %arg1
return %result : vector<8x8xi32>
}
+// -----
+
+// Same as @contract_broadcast_unit_dim_reduction, but with a mask and a scalable dim (mask's unit dim is dropped too).
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: contract_broadcast_unit_dim_reduction_masked_scalable
+// CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<[8]x4xi32>, %[[ARG2:.+]]: vector<8x[8]xi32>, %[[MASK:.+]]: vector<1x8x[8]x4xi1>)
+// CHECK: %[[MASK_SC:.*]] = vector.shape_cast %[[MASK]] : vector<1x8x[8]x4xi1> to vector<8x[8]x4xi1>
+// CHECK: %[[R:.*]] = vector.mask %[[MASK_SC]] {
+// CHECK-SAME: vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<[8]x4xi32> into vector<8x[8]xi32>
+func.func @contract_broadcast_unit_dim_reduction_masked_scalable(%arg0 : vector<8x4xi32>, %arg1 : vector<[8]x4xi32>, %arg2 : vector<8x[8]xi32>, %mask: vector<1x8x[8]x4xi1>) -> vector<8x[8]xi32> {
+ %0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<1x8x4xi32>
+ %1 = vector.broadcast %arg1 : vector<[8]x4xi32> to vector<1x[8]x4xi32>
+ %result = vector.mask %mask {
+ vector.contract {
+ indexing_maps = [#map0, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>
+ } %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x[8]x4xi32> into vector<8x[8]xi32>
+ } : vector<1x8x[8]x4xi1> -> vector<8x[8]xi32>
+ return %result : vector<8x[8]xi32>
+}
+
+// -----
+
+// Same as @contract_broadcast_unit_dim_reduction, but with a mask (fixed-size dims only; mask's unit dim is dropped too).
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: contract_broadcast_unit_dim_reduction_masked
+// CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<8x4xi32>, %[[ARG2:.+]]: vector<8x8xi32>, %[[MASK:.+]]: vector<1x8x8x4xi1>)
+// CHECK: %[[MASK_SC:.*]] = vector.shape_cast %[[MASK]] : vector<1x8x8x4xi1> to vector<8x8x4xi1>
+// CHECK: %[[R:.*]] = vector.mask %[[MASK_SC]] {
+// CHECK-SAME: vector.contract
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32>
+func.func @contract_broadcast_unit_dim_reduction_masked(%arg0 : vector<8x4xi32>, %arg1 : vector<8x4xi32>, %arg2 : vector<8x8xi32>, %mask: vector<1x8x8x4xi1>) -> vector<8x8xi32> {
+ %0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<1x8x4xi32>
+ %1 = vector.broadcast %arg1 : vector<8x4xi32> to vector<1x8x4xi32>
+ %result = vector.mask %mask {
+ vector.contract {
+ indexing_maps = [#map0, #map1, #map2],
+ iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>
+ } %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x8x4xi32> into vector<8x8xi32>
+ } : vector<1x8x8x4xi1> -> vector<8x8xi32>
+ return %result : vector<8x8xi32>
+}
----------------
hanhanW wrote:
I think you need to swap the two test cases? The first one has a scalable dim, and the second one is the case with a mask?
https://github.com/llvm/llvm-project/pull/140050
More information about the Mlir-commits mailing list