[Mlir-commits] [mlir] [mlir][Hoisting] Hoisting vector.extract/vector.broadcast pairs (PR #86108)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Thu Apr 18 04:32:59 PDT 2024
================
@@ -565,3 +565,112 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// Test hoisting of vector.extract/vector.broadcast pairs: the extract moves above the loop, the broadcast below it, and the loop carries vector<4xf32>.
+
+// CHECK-LABEL: func.func @hoist_vector_broadcasts
+// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>) -> vector<3x4xf32> {
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][0] : vector<4xf32> from vector<3x4xf32>
+// CHECK-NEXT: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%{{.+}} = %[[EXTRACT]]) {{.*}} {
+// CHECK-NEXT: %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// CHECK-NEXT: scf.yield %[[USE]] : vector<4xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
+// CHECK-NEXT: return %[[BCAST]] : vector<3x4xf32>
+
+func.func @hoist_vector_broadcasts(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>) -> vector<3x4xf32> {
+  %bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
+    %extract = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
+    %use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
+    %broadcast = vector.broadcast %use : vector<4xf32> to vector<3x4xf32>
+    scf.yield %broadcast : vector<3x4xf32>
+  }
+  return %bcast_vec : vector<3x4xf32>
+}
+
+module attributes {transform.with_named_sequence} {  // Match every func.func and run broadcast hoisting on it.
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %funcs = transform.structured.match ops{["func.func"]} in %root
+        : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_broadcasts %funcs
+        : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Test hoisting of vector.extract/vector.broadcast pairs whose extract position is a dynamic, loop-invariant SSA value.
+
+// CHECK-LABEL: func.func @hoist_vector_broadcasts_dynamic
+// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>, %[[POS:.+]]: index) -> vector<3x4xf32> {
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][%[[POS]]] : vector<4xf32> from vector<3x4xf32>
+// CHECK-NEXT: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%{{.+}} = %[[EXTRACT]]) {{.*}} {
+// CHECK-NEXT: %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// CHECK-NEXT: scf.yield %[[USE]] : vector<4xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
+// CHECK-NEXT: return %[[BCAST]] : vector<3x4xf32>
+
+func.func @hoist_vector_broadcasts_dynamic(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>, %pos: index) -> vector<3x4xf32> {
+  %bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
+    %extract = vector.extract %iarg[%pos] : vector<4xf32> from vector<3x4xf32>
+    %use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
+    %broadcast = vector.broadcast %use : vector<4xf32> to vector<3x4xf32>
+    scf.yield %broadcast : vector<3x4xf32>
+  }
+  return %bcast_vec : vector<3x4xf32>
+}
+
+module attributes {transform.with_named_sequence} {  // Match every func.func and run broadcast hoisting on it.
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %funcs = transform.structured.match ops{["func.func"]} in %root
+        : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_broadcasts %funcs
+        : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Test hoisting of vector.extract/vector.broadcast pairs with multiple iter_args: each pair is hoisted and each loop-carried value keeps its position.
+
+// CHECK-LABEL: func.func @hoist_vector_broadcasts_multiple
+// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC1:.+]]: vector<3x4xf32>,
+// CHECK-SAME: %[[VEC2:.+]]: vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
+// CHECK-DAG: %[[EXTRACT1:.+]] = vector.extract %[[VEC1]][0] : vector<4xf32> from vector<3x4xf32>
+// CHECK-DAG: %[[EXTRACT2:.+]] = vector.extract %[[VEC2]][1] : vector<5xf32> from vector<3x5xf32>
+// CHECK-NEXT: %[[LOOP:.+]]:2 = scf.for {{.*}} iter_args(%{{.+}} = %[[EXTRACT1]], %{{.+}} = %[[EXTRACT2]]) {{.*}} {
+// CHECK-DAG: %[[USE1:.+]] = "some_use1"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// CHECK-DAG: %[[USE2:.+]] = "some_use2"({{.*}}) : (vector<5xf32>) -> vector<5xf32>
+// CHECK-NEXT: scf.yield %[[USE1]], %[[USE2]] : vector<4xf32>, vector<5xf32>
+// CHECK-NEXT: }
+// CHECK-DAG: %[[BCAST1:.+]] = vector.broadcast %[[LOOP]]#0 : vector<4xf32> to vector<3x4xf32>
+// CHECK-DAG: %[[BCAST2:.+]] = vector.broadcast %[[LOOP]]#1 : vector<5xf32> to vector<3x5xf32>
+// CHECK-NEXT: return %[[BCAST1]], %[[BCAST2]] : vector<3x4xf32>, vector<3x5xf32>
+
+func.func @hoist_vector_broadcasts_multiple(%lb : index, %ub : index, %step : index, %vec1 : vector<3x4xf32>, %vec2 : vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
+  %bcast_vec:2 = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec1, %iarg2 = %vec2) -> (vector<3x4xf32>, vector<3x5xf32>) {
+    %extract1 = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
+    %extract2 = vector.extract %iarg2[1] : vector<5xf32> from vector<3x5xf32>
+    %use1 = "some_use1"(%extract1) : (vector<4xf32>) -> vector<4xf32>
+    %use2 = "some_use2"(%extract2) : (vector<5xf32>) -> vector<5xf32>
+    %broadcast1 = vector.broadcast %use1 : vector<4xf32> to vector<3x4xf32>
+    %broadcast2 = vector.broadcast %use2 : vector<5xf32> to vector<3x5xf32>
+    scf.yield %broadcast1, %broadcast2 : vector<3x4xf32>, vector<3x5xf32>
+  }
+  return %bcast_vec#0, %bcast_vec#1 : vector<3x4xf32>, vector<3x5xf32>
+}
+
+module attributes {transform.with_named_sequence} {  // Match every func.func and run broadcast hoisting on it.
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %funcs = transform.structured.match ops{["func.func"]} in %root
+        : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_broadcasts %funcs
+        : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
----------------
ftynse wrote:
Nit: please add a trailing newline at the end of the file.
https://github.com/llvm/llvm-project/pull/86108
More information about the Mlir-commits
mailing list