[Mlir-commits] [mlir] [mlir] Allow unroll & jam on SCF loops with results (PR #98887)

Javier Setoain llvmlistbot at llvm.org
Tue Jul 16 02:22:24 PDT 2024


https://github.com/jsetoain updated https://github.com/llvm/llvm-project/pull/98887

From da9ffdfb234f837a5af8dd894b5cb851de838a16 Mon Sep 17 00:00:00 2001
From: Javier Setoain <javier.setoain at gmail.com>
Date: Mon, 15 Jul 2024 05:18:18 -0600
Subject: [PATCH] [mlir] Allow unroll & jam on SCF loops with results

Unlike the affine version, the unroll & jam implementation for SCF loops
does not support loops with results/iter_args, but there's no real
reason to have this difference between the two.

Even though `iter_args` may indicate a loop-carried dependency and,
therefore, that the loop is unsuitable for unroll & jam, there are many
transformations that materialize loops with "transient" `iter_args` that
don't represent real dependencies and will eventually go away, e.g.
`linalg::tileLinalgOp` on non-bufferized linalg ops.

Given that this transformation doesn't perform a full dependency
analysis to ensure its safety, it's already up to the user to make sure
that the loop is parallel before proceeding. Allowing loops with
results makes the transformation more widely applicable without any
real loss of safety.
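
For illustration only (this sketch is not part of the patch, and the
constants/names such as %c0, %c1, %c2, %c8, and %init are assumed to be
defined elsewhere): unroll & jam by a factor of 2 on a loop with a single
iter_arg duplicates the iter_arg, the body, and the yield, and multiplies
the step; the rewrite is only meaningful when the iter_arg carries no true
loop dependence:

  // Before: one iter_arg, step 1.
  %r = scf.for %i = %c0 to %c8 step %c1 iter_args(%acc = %init) -> (index) {
    %v = arith.addi %i, %i : index
    scf.yield %v : index
  }

  // After unroll & jam by 2 (sketch): step doubled, iter_args and yields
  // duplicated; uses of the original result map to result #0, matching the
  // tests added below.
  %r:2 = scf.for %i = %c0 to %c8 step %c2
      iter_args(%acc0 = %init, %acc1 = %init) -> (index, index) {
    %v0 = arith.addi %i, %i : index
    %i1 = arith.addi %i, %c1 : index
    %v1 = arith.addi %i1, %i1 : index
    scf.yield %v0, %v1 : index, index
  }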
---
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          |  21 ++--
 .../Dialect/SCF/transform-ops-invalid.mlir    |  23 ----
 mlir/test/Dialect/SCF/transform-ops.mlir      | 113 ++++++++++++++++++
 3 files changed, 123 insertions(+), 34 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index c0ee9d2afe91c..d76707bc194ef 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -486,20 +486,25 @@ LogicalResult mlir::loopUnrollByFactor(
 }
 
 /// Check if bounds of all inner loops are defined outside of `forOp`
-/// and return false if not.
+/// or defined by constants, and return false if not.
 static bool areInnerBoundsInvariant(scf::ForOp forOp) {
   auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
-    if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
-        !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
-        !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
+    auto isValueInvariantInLoop = [&forOp](Value v) -> bool {
+      return forOp.isDefinedOutsideOfLoop(v) ||
+             isa<arith::ConstantOp>(v.getDefiningOp());
+    };
+    if (!isValueInvariantInLoop(innerForOp.getLowerBound()) ||
+        !isValueInvariantInLoop(innerForOp.getUpperBound()) ||
+        !isValueInvariantInLoop(innerForOp.getStep()))
       return WalkResult::interrupt();
-
     return WalkResult::advance();
   });
   return !walkResult.wasInterrupted();
 }
 
 /// Unrolls and jams this loop by the specified factor.
+/// This function does not verify that the loop is parallel; if there are true
+/// loop-carried dependencies, it will produce invalid code.
 LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
                                           uint64_t unrollJamFactor) {
   assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
@@ -514,12 +519,6 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
     return failure();
   }
 
-  // Currently, for operations with results are not supported.
-  if (forOp->getNumResults() > 0) {
-    LDBG("failed to unroll and jam: unsupported loop with results");
-    return failure();
-  }
-
   // Currently, only constant trip count that divided by the unroll factor is
   // supported.
   std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
diff --git a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
index 742b8a2861839..42f7761874a99 100644
--- a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
@@ -63,29 +63,6 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @loop_unroll_and_jam_unsupported_loop_with_results() -> index {
-  %c0 = arith.constant 0 : index
-  %c40 = arith.constant 40 : index
-  %c2 = arith.constant 2 : index
-  %sum = scf.for %i = %c0 to %c40 step %c2 iter_args(%does_not_alias_aggregated = %c0) -> (index) {
-    %sum = arith.addi %i, %i : index
-    scf.yield %sum : index
-  }
-  return %sum : index
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
-    // expected-error @below {{failed to unroll and jam}}
-    transform.loop.unroll_and_jam %1 { factor = 4 } : !transform.op<"scf.for">
-    transform.yield
-  }
-}
-
-// -----
-
 func.func private @loop_unroll_and_jam_unsupported_dynamic_trip_count(%arg0: memref<96x128xi8, 3>, %arg1: memref<128xi8, 3>) {
   %c96 = arith.constant 96 : index
   %c1 = arith.constant 1 : index
diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir
index d9445182769e7..d01efcdae5f19 100644
--- a/mlir/test/Dialect/SCF/transform-ops.mlir
+++ b/mlir/test/Dialect/SCF/transform-ops.mlir
@@ -336,6 +336,119 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: @loop_unroll_and_jam_loop_with_results
+func.func @loop_unroll_and_jam_loop_with_results() -> index {
+  // CHECK:           %[[C0:.*]] = arith.constant 0
+  // CHECK:           %[[UB:.*]] = arith.constant 40
+  // CHECK:           %[[STEP:.*]] = arith.constant 8
+  %c0 = arith.constant 0 : index
+  %c40 = arith.constant 40 : index
+  %c2 = arith.constant 2 : index
+  // CHECK:           %[[RES:.*]]:4 = scf.for %[[I:.*]] = %[[C0]] to %[[UB]] step %[[STEP]]
+  // CHECK-SAME:       iter_args(%[[ARG0:.*]] = %[[C0]], %[[ARG1:.*]] = %[[C0]],
+  // CHECK-SAME:                %[[ARG2:.*]] = %[[C0]], %[[ARG3:.*]] = %[[C0]])
+  %sum = scf.for %i = %c0 to %c40 step %c2 iter_args(%does_not_alias_aggregated = %c0) -> (index) {
+    %sum = arith.addi %i, %i : index
+    // CHECK:         scf.yield %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : index, index, index, index
+    scf.yield %sum : index
+  }
+  return %sum : index
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+    transform.loop.unroll_and_jam %1 { factor = 4 } : !transform.op<"scf.for">
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @unroll_jam_tiled_loops
+func.func @unroll_jam_tiled_loops(%A : tensor<8x16x4x8xbf16>, %B : tensor<16x8x8x4xbf16>) -> tensor<16x16x4x4xf32> {
+  // CHECK:      %[[C2:.*]] = arith.constant 2 : index
+  // CHECK:      %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:      %[[C1:.*]] = arith.constant 1 : index
+  // CHECK:      %[[C8:.*]] = arith.constant 8 : index
+  // CHECK:      %[[C16:.*]] = arith.constant 16 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c8 = arith.constant 8 : index
+  %c16 = arith.constant 16 : index
+  %c0_f32 = arith.constant 0.0 : f32
+  %buf = memref.alloc() : memref<16x16x4x4xf32>
+  %ten = bufferization.to_tensor %buf restrict writable : memref<16x16x4x4xf32>
+  // CHECK:      %[[AT:.*]] = linalg.fill {{.*}} -> tensor<16x16x4x4xf32>
+  %acc = linalg.fill ins(%c0_f32 : f32) outs(%ten : tensor<16x16x4x4xf32>) -> tensor<16x16x4x4xf32>
+  // CHECK:      %[[L0R:.*]]:2 = scf.for %{{.*}} = %[[C0]] to %[[C16]] step %[[C2]]
+  // CHECK-SAME:     iter_args(%[[L0IA0:.*]] = %[[AT]], %[[L0IA1:.*]] = %[[AT]])
+  %l0r = scf.for %i = %c0 to %c16 step %c1 iter_args(%acl0l1 = %acc) -> (tensor<16x16x4x4xf32>) {
+    // CHECK:        %[[L1R:.*]]:4 = scf.for %{{.*}} = %[[C0]] to %[[C16]] step %[[C2]]
+    // CHECK-SAME:       iter_args(%[[L1IA0:.*]] = %[[L0IA0]], %[[L1IA1:.*]] = %[[L0IA0]],
+    // CHECK-SAME:                 %[[L1IA2:.*]] = %[[L0IA1]], %[[L1IA3:.*]] = %[[L0IA1]])
+    %l1r = scf.for %j = %c0 to %c16 step %c1 iter_args(%acl1l2 = %acl0l1) -> (tensor<16x16x4x4xf32>) {
+      // CHECK:          %[[L2R:.*]]:4 = scf.for %{{.*}} = %[[C0]] to %[[C8]] step %[[C1]]
+      // CHECK-SAME:         iter_args(%[[L2IA0:.*]] = %[[L1IA0]], %[[L2IA1:.*]] = %[[L1IA1]],
+      // CHECK-SAME:                   %[[L2IA2:.*]] = %[[L1IA2]], %[[L2IA3:.*]] = %[[L1IA3]])
+      %l2r = scf.for %k = %c0 to %c8 step %c1 iter_args(%C = %acl1l2) -> (tensor<16x16x4x4xf32>) {
+        %ta = tensor.extract_slice %A[%k, %i, 0, 0] [1, 1, 4, 8] [1, 1, 1, 1] : tensor<8x16x4x8xbf16> to tensor<1x1x4x8xbf16>
+        %tb = tensor.extract_slice %B[%j, %k, 0, 0] [1, 1, 8, 4] [1, 1, 1, 1] : tensor<16x8x8x4xbf16> to tensor<1x1x8x4xbf16>
+        %tc = tensor.extract_slice %C[%j, %i, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : tensor<16x16x4x4xf32> to tensor<1x1x4x4xf32>
+        %rr = linalg.generic {
+                    indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>,
+                                     affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>,
+                                     affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>],
+                    iterator_types = ["parallel", "parallel", "reduction",
+                                      "parallel", "parallel", "reduction"]}
+                ins(%ta, %tb : tensor<1x1x4x8xbf16>, tensor<1x1x8x4xbf16>)
+                outs(%tc : tensor<1x1x4x4xf32>) {
+              ^bb0(%ia: bf16, %ib: bf16, %out: f32):
+                %0 = arith.extf %ia : bf16 to f32
+                %1 = arith.extf %ib : bf16 to f32
+                %2 = arith.mulf %0, %1 : f32
+                %3 = arith.addf %out, %2 : f32
+                linalg.yield %3 : f32
+        } -> tensor<1x1x4x4xf32>
+        %is = tensor.insert_slice %rr into %C[%j, %i, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : tensor<1x1x4x4xf32> into tensor<16x16x4x4xf32>
+        // CHECK:            scf.yield %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} :
+        // CHECK-SAME:           tensor<16x16x4x4xf32>, tensor<16x16x4x4xf32>,
+        // CHECK-SAME:           tensor<16x16x4x4xf32>, tensor<16x16x4x4xf32>
+        scf.yield %is : tensor<16x16x4x4xf32>
+      }
+      // CHECK:          scf.yield %[[L2R]]#0, %[[L2R]]#1, %[[L2R]]#2, %[[L2R]]#3 :
+      // CHECK-SAME:         tensor<16x16x4x4xf32>, tensor<16x16x4x4xf32>,
+      // CHECK-SAME:         tensor<16x16x4x4xf32>, tensor<16x16x4x4xf32>
+      scf.yield %l2r : tensor<16x16x4x4xf32>
+    }
+    // CHECK:        scf.yield %[[L1R]]#0, %[[L1R]]#2 :
+    // CHECK-SAME:       tensor<16x16x4x4xf32>, tensor<16x16x4x4xf32>
+    scf.yield %l1r : tensor<16x16x4x4xf32>
+  }
+  // CHECK:      return %[[L0R]]#0 : tensor<16x16x4x4xf32>
+  return %l0r : tensor<16x16x4x4xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %l2 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+    %l1 = transform.get_parent_op %l2 {op_name = "scf.for"} : (!transform.op<"scf.for">) -> !transform.op<"scf.for">
+    %l0 = transform.get_parent_op %l1 {op_name = "scf.for"} : (!transform.op<"scf.for">) -> !transform.op<"scf.for">
+    %mf = transform.get_parent_op %l0 {op_name = "func.func"} : (!transform.op<"scf.for">) -> !transform.any_op
+    transform.loop.unroll_and_jam %l1 {factor = 2 : i64} : !transform.op<"scf.for">
+    transform.loop.unroll_and_jam %l0 {factor = 2 : i64} : !transform.op<"scf.for">
+    transform.apply_patterns to %mf {
+      transform.apply_patterns.canonicalization
+    } : !transform.any_op
+    transform.apply_cse to %mf : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: @loop_unroll_and_jam_op
 // CHECK:  %[[VAL_0:.*]]: memref<21x30xf32, 1>, %[[INIT0:.*]]: f32, %[[INIT1:.*]]: f32) {
 func.func @loop_unroll_and_jam_op(%arg0: memref<21x30xf32, 1>, %init : f32, %init1 : f32) {


