[Mlir-commits] [mlir] [mlir][Vector] Improve support for vector.extract(broadcast) (PR #116234)
Kunwar Grover
llvmlistbot at llvm.org
Thu Feb 20 03:07:56 PST 2025
================
@@ -747,57 +775,51 @@ func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32
// CHECK-LABEL: fold_extract_splat
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: return %[[A]] : f32
-func.func @fold_extract_splat(%a : f32) -> f32 {
+func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
%b = vector.splat %a : vector<1x2x4xf32>
- %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+ %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
return %r : f32
}
// -----
-// CHECK-LABEL: fold_extract_broadcast_vector
-// CHECK-SAME: %[[A:.*]]: vector<4xf32>
-// CHECK: return %[[A]] : vector<4xf32>
-func.func @fold_extract_broadcast_vector(%a : vector<4xf32>) -> vector<4xf32> {
- %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
- %r = vector.extract %b[0, 1] : vector<4xf32> from vector<1x2x4xf32>
- return %r : vector<4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: fold_extract_broadcast
-// CHECK-SAME: %[[A:.*]]: vector<4xf32>
-// CHECK: %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
+// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
+// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>
+// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
+// CHECK: %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32>
// CHECK: return %[[R]] : f32
-func.func @fold_extract_broadcast(%a : vector<4xf32>) -> f32 {
- %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
- %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
+ %idx : index, %idx1 : index, %idx2 : index) -> f32 {
+ %b = vector.broadcast %a : vector<2x1xf32> to vector<1x2x4xf32>
+ %r = vector.extract %b[%idx, %idx1, %idx2] : f32 from vector<1x2x4xf32>
return %r : f32
}
// -----
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: fold_extract_broadcast_to_higher_rank
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
// CHECK: return %[[B]] : vector<4xf32>
-func.func @fold_extract_broadcast(%a : f32) -> vector<4xf32> {
+func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1 : index)
+ -> vector<4xf32> {
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
- %r = vector.extract %b[0, 1] : vector<4xf32> from vector<1x2x4xf32>
+ %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
return %r : vector<4xf32>
----------------
Groverkss wrote:
@fold_extract_broadcast_diff_input_output_vec --> rank(extract_output) < rank(broadcast_input)
@fold_extract_broadcast_to_higher_rank --> rank(extract_output) > rank(broadcast_input)
@fold_extract_broadcast_to_equal_rank --> rank(extract_output) = rank(broadcast_input)
I can make the naming more consistent by renaming them to:
@fold_extract_broadcast_to_lower_rank --> rank(extract_output) < rank(broadcast_input)
@fold_extract_broadcast_to_higher_rank --> rank(extract_output) > rank(broadcast_input)
@fold_extract_broadcast_to_equal_rank --> rank(extract_output) = rank(broadcast_input)
and add a comment above each test stating which case it covers.
https://github.com/llvm/llvm-project/pull/116234
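To make the three proposed names concrete, here is a minimal Python sketch. It is a hypothetical model, not MLIR code: `fold_case` and `remap_positions` are illustrative names, shapes are plain tuples (an `f32` scalar is `()`), and `vector.broadcast` is assumed to follow NumPy-style rules (new leading dims, stretched unit dims). The rank comparison is chosen to agree with the `fold_extract_broadcast_to_higher_rank` test in the diff, where an `f32` (rank 0) is broadcast and a `vector<4xf32>` (rank 1) is extracted. The position remapping reproduces the `[%[[IDX1]], 0]` positions expected by `fold_extract_broadcast_dim1_broadcasting`.

```python
# Hypothetical model of vector.extract(vector.broadcast) rank cases.
# Shapes are Python tuples; extract positions are ints or SSA-name strings.
# This illustrates the test-naming scheme; it is not MLIR's implementation.
from typing import List, Sequence, Union

Position = Union[int, str]


def fold_case(broadcast_input_shape: Sequence[int],
              broadcast_result_shape: Sequence[int],
              num_positions: int) -> str:
    """Return the proposed test-name suffix for extract(broadcast)."""
    # An extract with k positions removes the k leading dims of its source.
    output_rank = len(broadcast_result_shape) - num_positions
    input_rank = len(broadcast_input_shape)
    if output_rank < input_rank:
        return "to_lower_rank"
    if output_rank > input_rank:
        return "to_higher_rank"
    return "to_equal_rank"


def remap_positions(broadcast_input_shape: Sequence[int],
                    broadcast_result_shape: Sequence[int],
                    positions: Sequence[Position]) -> List[Position]:
    """Map extract positions on the broadcast result onto the broadcast input.

    Assumes NumPy-style broadcasting: leading result dims added by the
    broadcast are dropped, and input dims of size 1 that were stretched are
    indexed with 0, since every index along them reads the same element.
    If trailing dims of the extracted slice were stretched, a further
    broadcast would still be needed; this sketch does not model that.
    """
    rank_diff = len(broadcast_result_shape) - len(broadcast_input_shape)
    source: List[Position] = []
    for result_dim, pos in enumerate(positions):
        if result_dim < rank_diff:
            continue  # dim added by the broadcast; the index is irrelevant
        input_dim = result_dim - rank_diff
        stretched = (broadcast_input_shape[input_dim] == 1
                     and broadcast_result_shape[result_dim] != 1)
        source.append(0 if stretched else pos)
    return source


if __name__ == "__main__":
    # fold_extract_broadcast_dim1_broadcasting: vector<2x1xf32> -> vector<1x2x4xf32>
    print(fold_case((2, 1), (1, 2, 4), 3))  # to_lower_rank
    print(remap_positions((2, 1), (1, 2, 4), ["%idx", "%idx1", "%idx2"]))  # ['%idx1', 0]
    # fold_extract_broadcast_to_higher_rank: f32 -> vector<1x2x4xf32>
    print(fold_case((), (1, 2, 4), 2))  # to_higher_rank
```

Under this model the removed `fold_extract_broadcast_vector` test (a `vector<4xf32>` broadcast to `vector<1x2x4xf32>`, then a `vector<4xf32>` extracted) falls into the `to_equal_rank` case.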