[Mlir-commits] [mlir] [mlir] Allow unroll & jam on SCF loops with results (PR #98887)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jul 15 12:07:54 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Javier Setoain (jsetoain)
<details>
<summary>Changes</summary>
Unlike the affine version, the unroll & jam version for SCF loops does not support loops with results/iter_args, but there's no real reason for this difference between the two.
Even though `iter_args` may indicate a loop-carried dependency and, therefore, its unsuitability for unroll & jam, there are many transformations that materialize loops with "transient" `iter_args` that don't represent real dependencies, and will eventually go away. E.g.: `linalg::tileLinalgOp` on non-bufferized linalg ops.
Given that this transformation doesn't perform a full dependency analysis to ensure its safety, it's already up to the user to make sure that the loop is parallel before proceeding. Allowing loops with results makes this transformation more widely applicable without any real loss of safety.
---
Full diff: https://github.com/llvm/llvm-project/pull/98887.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+9-10)
- (modified) mlir/test/Dialect/SCF/transform-ops-invalid.mlir (-23)
- (modified) mlir/test/Dialect/SCF/transform-ops.mlir (+27)
``````````diff
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index c0ee9d2afe91c..b390658009587 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -486,12 +486,17 @@ LogicalResult mlir::loopUnrollByFactor(
}
/// Check if bounds of all inner loops are defined outside of `forOp`
-/// and return false if not.
+/// or defined by constants, and return false if not.
static bool areInnerBoundsInvariant(scf::ForOp forOp) {
auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
- if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
- !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
- !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
+ // Block-argument bounds (e.g. an enclosing induction variable) have no
+ // defining op, so use the null-safe getDefiningOp<OpTy>() form.
+ if (!(forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
+ innerForOp.getLowerBound().getDefiningOp<arith::ConstantOp>()) ||
+ !(forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
+ innerForOp.getUpperBound().getDefiningOp<arith::ConstantOp>()) ||
+ !(forOp.isDefinedOutsideOfLoop(innerForOp.getStep()) ||
+ innerForOp.getStep().getDefiningOp<arith::ConstantOp>()))
return WalkResult::interrupt();
return WalkResult::advance();
@@ -500,6 +503,8 @@ static bool areInnerBoundsInvariant(scf::ForOp forOp) {
}
/// Unrolls and jams this loop by the specified factor.
+/// This function doesn't verify that the loop is parallel: if there are true
+/// loop-carried dependencies, this function will produce invalid code.
LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
uint64_t unrollJamFactor) {
assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
@@ -514,12 +519,6 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
return failure();
}
- // Currently, for operations with results are not supported.
- if (forOp->getNumResults() > 0) {
- LDBG("failed to unroll and jam: unsupported loop with results");
- return failure();
- }
-
// Currently, only constant trip count that divided by the unroll factor is
// supported.
std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
diff --git a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
index 742b8a2861839..42f7761874a99 100644
--- a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
@@ -63,29 +63,6 @@ module attributes {transform.with_named_sequence} {
// -----
-func.func @loop_unroll_and_jam_unsupported_loop_with_results() -> index {
- %c0 = arith.constant 0 : index
- %c40 = arith.constant 40 : index
- %c2 = arith.constant 2 : index
- %sum = scf.for %i = %c0 to %c40 step %c2 iter_args(%does_not_alias_aggregated = %c0) -> (index) {
- %sum = arith.addi %i, %i : index
- scf.yield %sum : index
- }
- return %sum : index
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
- // expected-error @below {{failed to unroll and jam}}
- transform.loop.unroll_and_jam %1 { factor = 4 } : !transform.op<"scf.for">
- transform.yield
- }
-}
-
-// -----
-
func.func private @loop_unroll_and_jam_unsupported_dynamic_trip_count(%arg0: memref<96x128xi8, 3>, %arg1: memref<128xi8, 3>) {
%c96 = arith.constant 96 : index
%c1 = arith.constant 1 : index
diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir
index d9445182769e7..4351df5961704 100644
--- a/mlir/test/Dialect/SCF/transform-ops.mlir
+++ b/mlir/test/Dialect/SCF/transform-ops.mlir
@@ -336,6 +336,33 @@ module attributes {transform.with_named_sequence} {
// -----
+// Unroll & jam a loop that has a result. The iter_arg is never read in the
+// body, so it does not carry a real dependency between iterations.
+// CHECK-LABEL: @loop_unroll_and_jam_loop_with_results
+func.func @loop_unroll_and_jam_loop_with_results() -> index {
+ // CHECK: %[[LB:.*]] = arith.constant 0
+ // CHECK: %[[UB:.*]] = arith.constant 40
+ // Expected new step: original step (2) times the unroll factor (4).
+ // CHECK: %[[STEP:.*]] = arith.constant 8
+ %c0 = arith.constant 0 : index
+ %c40 = arith.constant 40 : index
+ %c2 = arith.constant 2 : index
+ // The jammed loop is expected to yield one result per unrolled copy (4).
+ // CHECK: %[[RES:[a-zA-Z0-9]+]]:4 = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]]
+ %sum = scf.for %i = %c0 to %c40 step %c2 iter_args(%does_not_alias_aggregated = %c0) -> (index) {
+ %sum = arith.addi %i, %i : index
+ scf.yield %sum : index
+ }
+ return %sum : index
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ // Select the scf.for enclosing the arith.addi and unroll & jam it by 4.
+ %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+ transform.loop.unroll_and_jam %1 { factor = 4 } : !transform.op<"scf.for">
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: @loop_unroll_and_jam_op
// CHECK: %[[VAL_0:.*]]: memref<21x30xf32, 1>, %[[INIT0:.*]]: f32, %[[INIT1:.*]]: f32) {
func.func @loop_unroll_and_jam_op(%arg0: memref<21x30xf32, 1>, %init : f32, %init1 : f32) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/98887
More information about the Mlir-commits
mailing list