[Mlir-commits] [mlir] [mlir][vector] Propagate `vector.extract` through elementwise ops (PR #131462)

Ivan Butygin llvmlistbot at llvm.org
Sat Mar 22 16:17:13 PDT 2025


================
@@ -423,3 +423,76 @@ func.func @transpose_elementwise_diff_map_scalable(%a : vector<[4]x6x3x2xf32>, %
   %r = arith.addf %at, %bt : vector<6x[4]x2x3xf32>
   return %r : vector<6x[4]x2x3xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @extract_elementwise_scalar
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+func.func @extract_elementwise_scalar(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
+// CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
+// CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
+// CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
+// CHECK:   return %[[RES]] : f32
+  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_elementwise_arg_res_different_types
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xindex>)
+func.func @extract_elementwise_arg_res_different_types(%arg0: vector<4xindex>) -> i64 {
+// CHECK:   %[[EXT:.*]] = vector.extract %[[ARG0]][1] : index from vector<4xindex>
+// CHECK:   %[[RES:.*]] = arith.index_cast %[[EXT]] : index to i64
+// CHECK:   return %[[RES]] : i64
+  %0 = arith.index_cast %arg0: vector<4xindex> to vector<4xi64>
+  %1 = vector.extract %0[1] : i64 from vector<4xi64>
+  return %1 : i64
+}
+
+// CHECK-LABEL: @extract_elementwise_vec
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
+func.func @extract_elementwise_vec(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
+// CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32>
+// CHECK:   return %[[RES]] : vector<4xf32>
+  %0 = arith.addf %arg0, %arg1 : vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extract_elementwise_no_single_use
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+func.func @extract_elementwise_no_single_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
----------------
Hardcode84 wrote:

done

https://github.com/llvm/llvm-project/pull/131462
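The tests in the hunk above check one rewrite. A `vector.extract` of an elementwise op's result becomes the same elementwise op applied to operands that are extracted first. The cases covered are a scalar extract, a row extract from a 2-D vector, and a cast whose result type differs from its operand type.

The Python sketch below illustrates why the two orders agree. It is not the PR's C++ rewrite pattern and does not use any MLIR API. Vectors are modelled as nested lists, and the helpers `elementwise`, `extract`, and `propagate_extract` are hypothetical names chosen for this example.

```python
from operator import add


def elementwise(fn, *operands):
    """Apply fn lane by lane across same-shaped nested lists.

    Hypothetical helper: it mimics an elementwise op such as arith.addf
    on vectors. It recurses while operands are lists and calls fn once
    the operands are scalars.
    """
    if isinstance(operands[0], list):
        return [elementwise(fn, *lanes) for lanes in zip(*operands)]
    return fn(*operands)


def extract(vec, position):
    """Index the leading dimensions, like `vector.extract %v[position]`."""
    for idx in position:
        vec = vec[idx]
    return vec


def propagate_extract(fn, operands, position):
    """Rewritten form: extract from each operand, then apply fn to the smaller values."""
    return elementwise(fn, *(extract(op, position) for op in operands))


if __name__ == "__main__":
    # Scalar extract, as in @extract_elementwise_scalar.
    a, b = [1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]
    assert extract(elementwise(add, a, b), [1]) == propagate_extract(add, [a, b], [1])

    # Row extract from a 2-D "vector", as in @extract_elementwise_vec.
    m = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
    n = [[10.0, 20.0, 30.0, 40.0], [50.0, 60.0, 70.0, 80.0]]
    assert extract(elementwise(add, m, n), [1]) == propagate_extract(add, [m, n], [1])

    # Result type differs from operand type (int -> float stands in for index_cast).
    idx = [0, 1, 2, 3]
    assert extract(elementwise(float, idx), [1]) == propagate_extract(float, [idx], [1])
    print("extract-then-op matches op-then-extract")
```

The sketch only shows that the values match. It says nothing about when the rewrite is profitable in MLIR; that is what cases like `@extract_elementwise_no_single_use` are there to pin down.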


More information about the Mlir-commits mailing list