[Mlir-commits] [mlir] [mlir] introduce transform.loop.forall_to_for (PR #65474)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Wed Sep 6 06:20:04 PDT 2023


https://github.com/ftynse created https://github.com/llvm/llvm-project/pull/65474:

Add a straightforward sequentialization transform from `scf.forall` to a nest of `scf.for` in the absence of results and expose it as a transform op. This is helpful in combination with other transform ops, particularly fusion, that work best on parallel-by-construction `scf.forall` but later need to target sequential `for` loops.

>From f450b35c7b325a8d677b5b5eb08c73ba75e27bd9 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Tue, 5 Sep 2023 16:04:52 +0000
Subject: [PATCH] [mlir] introduce transform.loop.forall_to_for

Add a straightforward sequentialization transform from `scf.forall` to a
nest of `scf.for` in the absence of results and expose it as a transform
op. This is helpful in combination with other transform ops, particularly
fusion, that work best on parallel-by-construction `scf.forall` but
later need to target sequential `for` loops.
---
 .../SCF/TransformOps/SCFTransformOps.h        |  1 +
 .../SCF/TransformOps/SCFTransformOps.td       | 27 +++++++
 .../SCF/TransformOps/SCFTransformOps.cpp      | 78 +++++++++++++++++++
 .../SCF/transform-op-forall-to-for.mlir       | 73 +++++++++++++++++
 4 files changed, 179 insertions(+)
 create mode 100644 mlir/test/Dialect/SCF/transform-op-forall-to-for.mlir

diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h
index 26cc9b16cd9ca7..d14d63e56dc764 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h
@@ -20,6 +20,7 @@ namespace func {
 class FuncOp;
 } // namespace func
 namespace scf {
+class ForallOp;
 class ForOp;
 class IfOp;
 } // namespace scf
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 3efc047ff0786d..3d6799b7e24049 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -29,6 +29,33 @@ def ApplyForLoopCanonicalizationPatternsOp : Op<Transform_Dialect,
 
 def Transform_ScfForOp : Transform_ConcreteOpType<"scf.for">;
 
+def ForallToForOp : Op<Transform_Dialect, "loop.forall_to_for",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Converts scf.forall into a nest of scf.for operations";
+  let description = [{
+    Converts the `scf.forall` operation pointed to by the given handle into a
+    set of nested `scf.for` operations. Each new operation corresponds to one
+    induction variable of the original "multifor" loop.
+
+    The operand handle must be associated with exactly one payload operation.
+
+    Loops with returning values are currently not supported.
+
+    #### Return Modes
+
+    Consumes the operand handle. Produces a silenceable failure if the operand
+    is not associated with a single `scf.forall` payload operation.
+    Returns as many handles as the given `forall` op has induction variables.
+    Produces a silenceable failure if another number of resulting handles is
+    requested.
+  }];
+  let arguments = (ins TransformHandleTypeInterface:$target); // Handle to the scf.forall payload op to convert.
+  let results = (outs Variadic<TransformHandleTypeInterface>:$transformed); // One handle per generated scf.for loop.
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+}
+
 def GetParentForOp : Op<Transform_Dialect, "loop.get_parent_for",
     [NavigationTransformOpTrait, MemoryEffectsOpInterface,
      DeclareOpInterfaceMethods<TransformOpInterface>]> {
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 5b8dd2c68b84e5..a069a6c638a145 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
@@ -17,8 +18,11 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/IR/OpDefinition.h"
 
 using namespace mlir;
 using namespace mlir::affine;
@@ -35,6 +39,7 @@ void transform::ApplyForLoopCanonicalizationPatternsOp::populatePatterns(
 //===----------------------------------------------------------------------===//
 // GetParentForOp
 //===----------------------------------------------------------------------===//
+
 DiagnosedSilenceableFailure
 transform::GetParentForOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &results,
@@ -64,6 +69,79 @@ transform::GetParentForOp::apply(transform::TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// ForallToForOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
+                                transform::TransformResults &results,
+                                transform::TransformState &state) { // Replaces a single scf.forall payload op (without shared outputs) by nested scf.for ops.
+  auto payload = state.getPayloadOps(getTarget()); // Payload ops associated with the target handle.
+  if (!llvm::hasSingleElement(payload)) { // Exactly one payload op is required.
+    return emitSilenceableError() << "expected a single payload op";
+  }
+  auto target = dyn_cast<scf::ForallOp>(*payload.begin()); // Null when the payload is not an scf.forall.
+  if (!target) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "expected the payload to be scf.forall";
+    diag.attachNote((*payload.begin())->getLoc()) << "payload op";
+    return diag;
+  }
+
+  rewriter.setInsertionPoint(target); // New ops are created right before the forall.
+
+  if (!target.getOutputs().empty()) { // Loops with shared outputs are rejected.
+    return emitSilenceableError()
+           << "unsupported shared outputs (didn't bufferize?)";
+  }
+
+  auto materialize = [](OpBuilder &b, Location loc, OpFoldResult r) -> Value { // Returns `r` as a Value, creating an index constant when it is an attribute.
+    if (Value v = r.dyn_cast<Value>())
+      return v;
+    return b.create<arith::ConstantIndexOp>(
+        loc, r.get<Attribute>().cast<IntegerAttr>().getValue().getSExtValue());
+  };
+
+  SmallVector<OpFoldResult> lbs = target.getMixedLowerBound();
+  SmallVector<OpFoldResult> ubs = target.getMixedUpperBound();
+  SmallVector<OpFoldResult> steps = target.getMixedStep();
+
+  if (getNumResults() != lbs.size()) { // One result handle is expected per induction variable; checked before any IR is created.
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError()
+        << "op expects as many results (" << getNumResults()
+        << ") as payload has induction variables (" << lbs.size() << ")";
+    diag.attachNote(target.getLoc()) << "payload op";
+    return diag;
+  }
+
+  auto loc = target.getLoc();
+  SmallVector<Value> ivs; // Induction variables of the new loops, outermost first.
+  for (auto &&[lb, ub, step] : llvm::zip(lbs, ubs, steps)) { // Each iteration nests a new scf.for inside the previous one.
+    Value lbValue = materialize(rewriter, loc, lb);
+    Value ubValue = materialize(rewriter, loc, ub);
+    Value stepValue = materialize(rewriter, loc, step);
+    auto loop = rewriter.create<scf::ForOp>(
+        loc, lbValue, ubValue, stepValue, ValueRange(),
+        [](OpBuilder &, Location, Value, ValueRange) {}); // No iter args; the body builder callback does nothing.
+    ivs.push_back(loop.getInductionVar());
+    rewriter.setInsertionPointToStart(loop.getBody());
+    rewriter.create<scf::YieldOp>(loc); // Explicit terminator for the new loop body.
+    rewriter.setInsertionPointToStart(loop.getBody()); // Whatever is created next lands before the yield.
+  }
+  rewriter.eraseOp(target.getBody()->getTerminator()); // Drop the forall terminator before inlining its body.
+  rewriter.inlineBlockBefore(target.getBody(), &*rewriter.getInsertionPoint(),
+                             ivs); // The scf.for induction variables stand in for the forall block arguments.
+  rewriter.eraseOp(target);
+
+  for (auto &&[i, iv] : llvm::enumerate(ivs)) { // Result i is the op owning the block in which ivs[i] is defined.
+    results.set(cast<OpResult>(getTransformed()[i]),
+                {iv.getParentBlock()->getParentOp()});
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // LoopOutlineOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/transform-op-forall-to-for.mlir b/mlir/test/Dialect/SCF/transform-op-forall-to-for.mlir
new file mode 100644
index 00000000000000..4b46c68d06d351
--- /dev/null
+++ b/mlir/test/Dialect/SCF/transform-op-forall-to-for.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file --verify-diagnostics | FileCheck %s
+
+func.func private @callee(%i: index, %j: index)
+
+// CHECK-LABEL: @two_iters
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
+func.func @two_iters(%ub1: index, %ub2: index) { // Positive case: a two-dimensional forall becomes two nested scf.for loops.
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
+  // CHECK:   scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
+  // CHECK:     func.call @callee(%[[IV1]], %[[IV2]])
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  transform.loop.forall_to_for %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) // Two result handles, one per induction variable.
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+
+func.func @repeated(%ub1: index, %ub2: index) { // Negative case: the function holds two scf.forall ops.
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op // Matches both scf.forall ops above.
+  // expected-error @below {{expected a single payload op}}
+  transform.loop.forall_to_for %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+
+func.func @repeated(%ub1: index, %ub2: index) { // Negative case: result handle count does not match the induction variable count.
+  // expected-note @below {{payload op}}
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  // expected-error @below {{op expects as many results (1) as payload has induction variables (2)}}
+  transform.loop.forall_to_for %0 : (!transform.any_op) -> !transform.any_op // Only one result handle is requested for two induction variables.
+}
+
+// -----
+
+// expected-note @below {{payload op}}
+func.func private @callee(%i: index, %j: index)
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op // Negative case: the handle points to func.func, not scf.forall.
+  // expected-error @below {{expected the payload to be scf.forall}}
+  transform.loop.forall_to_for %0 : (!transform.any_op) -> !transform.any_op
+}



More information about the Mlir-commits mailing list