[Mlir-commits] [mlir] [mlir][SCF] Avoid generating unnecessary div/rem operations during coalescing (PR #91562)
Quinn Dawkins
llvmlistbot at llvm.org
Thu May 9 14:52:30 PDT 2024
================
@@ -299,3 +299,80 @@ module attributes {transform.with_named_sequence} {
// CHECK-NOT: scf.for
// CHECK: transform.named_sequence
+// -----
+
+// Check that no unnecessary operations are generated while collapsing unit-trip loops.
+func.func @trip_one_loops(%arg0 : tensor<?x?xf32>, %arg1 : index, %arg2 : index) -> tensor<?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = scf.for %iv0 = %c0 to %c1 step %c1 iter_args(%iter0 = %arg0) -> tensor<?x?xf32> {
+ %1 = scf.for %iv1 = %c0 to %c1 step %c1 iter_args(%iter1 = %iter0) -> tensor<?x?xf32> {
+ %2 = scf.for %iv2 = %c0 to %arg1 step %c1 iter_args(%iter2 = %iter1) -> tensor<?x?xf32> {
+ %3 = scf.for %iv3 = %c0 to %c1 step %c1 iter_args(%iter3 = %iter2) -> tensor<?x?xf32> {
+ %4 = scf.for %iv4 = %c0 to %arg2 step %c1 iter_args(%iter4 = %iter3) -> tensor<?x?xf32> {
+ %5 = "some_use"(%iter4, %iv0, %iv1, %iv2, %iv3, %iv4)
+ : (tensor<?x?xf32>, index, index, index, index, index) -> (tensor<?x?xf32>)
+ scf.yield %5 : tensor<?x?xf32>
+ }
+ scf.yield %4 : tensor<?x?xf32>
+ }
+ scf.yield %3 : tensor<?x?xf32>
+ }
+ scf.yield %2 : tensor<?x?xf32>
+ }
+ scf.yield %1 : tensor<?x?xf32>
+ } {coalesce}
+ return %0 : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.cast %0 : !transform.any_op to !transform.op<"scf.for">
+ %2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">)
+ transform.yield
+ }
+}
+// CHECK-LABEL: func @trip_one_loops
+// CHECK-SAME: , %[[ARG1:.+]]: index,
+// CHECK-SAME: %[[ARG2:.+]]: index)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[UB:.+]] = arith.muli %[[ARG1]], %[[ARG2]]
+// CHECK: scf.for %[[IV:.+]] = %[[C0]] to %[[UB]] step %[[C1]]
+// CHECK: %[[IV1:.+]] = arith.remsi %[[IV]], %[[ARG2]]
+// CHECK: %[[IV2:.+]] = arith.divsi %[[IV]], %[[ARG2]]
+// CHECK: "some_use"(%{{[a-zA-Z0-9]+}}, %[[C0]], %[[C0]], %[[IV2]], %[[C0]], %[[IV1]])
+
+// -----
+
+// Check that no instructions are generated when all loops except one are unit-trip.
+func.func @all_outer_trip_one(%arg0 : tensor<?x?xf32>, %arg1 : index) -> tensor<?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = scf.for %iv0 = %c0 to %c1 step %c1 iter_args(%iter0 = %arg0) -> tensor<?x?xf32> {
+ %1 = scf.for %iv1 = %c0 to %c1 step %c1 iter_args(%iter1 = %iter0) -> tensor<?x?xf32> {
+ %2 = scf.for %iv2 = %c0 to %arg1 step %c1 iter_args(%iter2 = %iter1) -> tensor<?x?xf32> {
+ %3 = "some_use"(%iter2, %iv0, %iv1, %iv2)
+ : (tensor<?x?xf32>, index, index, index) -> (tensor<?x?xf32>)
+ scf.yield %3 : tensor<?x?xf32>
+ }
+ scf.yield %2 : tensor<?x?xf32>
+ }
+ scf.yield %1 : tensor<?x?xf32>
+ } {coalesce}
+ return %0 : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.cast %0 : !transform.any_op to !transform.op<"scf.for">
+ %2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">)
+ transform.yield
+ }
+}
+// CHECK-LABEL: func @all_outer_trip_one
+// CHECK-SAME: , %[[ARG1:.+]]: index)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: scf.for %[[IV:.+]] = %[[C0]] to %[[ARG1]] step %[[C1]]
+// CHECK: "some_use"(%{{[a-zA-Z0-9]+}}, %[[C0]], %[[C0]], %[[IV]])
----------------
qedawkins wrote:
SGTM, I was mainly asking to check if the pattern is correct in this case (tied with the earlier question)
https://github.com/llvm/llvm-project/pull/91562
More information about the Mlir-commits
mailing list