[Mlir-commits] [mlir] [mlir] introduce transform.loop.forall_to_for (PR #65474)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Wed Sep 20 06:55:46 PDT 2023
https://github.com/ftynse updated https://github.com/llvm/llvm-project/pull/65474
>From 970d81c8a1da6404c561086079c03fbdb218c226 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Tue, 5 Sep 2023 16:04:52 +0000
Subject: [PATCH] [mlir] introduce transform.loop.forall_to_for
Add a straightforward sequentialization transform from `scf.forall` to a
nest of `scf.for` in the absence of results and expose it as a transform op.
This is helpful in combination with other transform ops, particularly
fusion, that work best on parallel-by-construction `scf.forall` but
later need to target sequential `for` loops.
---
.../SCF/TransformOps/SCFTransformOps.h | 1 +
.../SCF/TransformOps/SCFTransformOps.td | 28 +++++++
.../SCF/TransformOps/SCFTransformOps.cpp | 72 ++++++++++++++++++
.../SCF/transform-op-forall-to-for.mlir | 73 +++++++++++++++++++
4 files changed, 174 insertions(+)
create mode 100644 mlir/test/Dialect/SCF/transform-op-forall-to-for.mlir
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h
index 26cc9b16cd9ca74..d14d63e56dc764f 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h
@@ -20,6 +20,7 @@ namespace func {
class FuncOp;
} // namespace func
namespace scf {
+class ForallOp;
class ForOp;
class IfOp;
} // namespace scf
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 6f48b005bfcbf6e..207a004c54ef5af 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -40,6 +40,34 @@ def ApplySCFStructuralConversionPatternsOp : Op<Transform_Dialect,
def Transform_ScfForOp : Transform_ConcreteOpType<"scf.for">;
+def ForallToForOp : Op<Transform_Dialect, "loop.forall_to_for",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> { // apply() is defined in SCFTransformOps.cpp
+ let summary = "Converts scf.forall into a nest of scf.for operations";
+ let description = [{
+ Converts the `scf.forall` operation pointed to by the given handle into a
+ set of nested `scf.for` operations. Each new operation corresponds to one
+ induction variable of the original "multifor" loop.
+
+ The operand handle must be associated with exactly one payload operation.
+
+ Loops with shared outputs are currently not supported.
+
+ #### Return Modes
+
+ Consumes the operand handle. Produces a silenceable failure if the operand
+ is not associated with a single `scf.forall` payload operation.
+ Returns as many handles as the given `forall` op has induction variables
+ that are associated with the generated `scf.for` loops.
+ Produces a silenceable failure if another number of resulting handles is
+ requested.
+ }];
+ let arguments = (ins TransformHandleTypeInterface:$target); // handle to the single scf.forall payload op
+ let results = (outs Variadic<TransformHandleTypeInterface>:$transformed); // one handle per generated scf.for loop
+
+ let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+}
+
def GetParentForOp : Op<Transform_Dialect, "loop.get_parent_for",
[NavigationTransformOpTrait, MemoryEffectsOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 88ddd22eea46b35..d7e8c38478ced1a 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -9,6 +9,8 @@
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
@@ -17,8 +19,11 @@
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h"
+#include "mlir/IR/OpDefinition.h"
using namespace mlir;
using namespace mlir::affine;
@@ -47,6 +52,7 @@ void transform::ApplySCFStructuralConversionPatternsOp::
//===----------------------------------------------------------------------===//
// GetParentForOp
//===----------------------------------------------------------------------===//
+
DiagnosedSilenceableFailure
transform::GetParentForOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
@@ -76,6 +82,72 @@ transform::GetParentForOp::apply(transform::TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// ForallToForOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure // Rewrites one scf.forall payload op into a nest of scf.for loops.
+transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto payload = state.getPayloadOps(getTarget()); // payload ops associated with the target handle
+ if (!llvm::hasSingleElement(payload)) // exactly one payload op is required
+ return emitSilenceableError() << "expected a single payload op";
+
+ auto target = dyn_cast<scf::ForallOp>(*payload.begin()); // null unless the payload is an scf.forall
+ if (!target) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError() << "expected the payload to be scf.forall";
+ diag.attachNote((*payload.begin())->getLoc()) << "payload op"; // note points at the offending payload op
+ return diag;
+ }
+
+ rewriter.setInsertionPoint(target); // new ops are created right before the forall
+
+ if (!target.getOutputs().empty()) { // loops with shared outputs are rejected
+ return emitSilenceableError()
+ << "unsupported shared outputs (didn't bufferize?)";
+ }
+
+ SmallVector<OpFoldResult> lbs = target.getMixedLowerBound(); // per-dimension bounds/steps; turned into Values below
+ SmallVector<OpFoldResult> ubs = target.getMixedUpperBound();
+ SmallVector<OpFoldResult> steps = target.getMixedStep();
+
+ if (getNumResults() != lbs.size()) { // one result handle per induction variable; checked before any IR is created
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError()
+ << "op expects as many results (" << getNumResults()
+ << ") as payload has induction variables (" << lbs.size() << ")";
+ diag.attachNote(target.getLoc()) << "payload op";
+ return diag;
+ }
+
+ auto loc = target.getLoc();
+ SmallVector<Value> ivs; // induction variables of the new loops, outermost first
+ for (auto &&[lb, ub, step] : llvm::zip(lbs, ubs, steps)) { // one scf.for per dimension, each nested in the previous one
+ Value lbValue = getValueOrCreateConstantIndexOp(rewriter, loc, lb);
+ Value ubValue = getValueOrCreateConstantIndexOp(rewriter, loc, ub);
+ Value stepValue = getValueOrCreateConstantIndexOp(rewriter, loc, step);
+ auto loop = rewriter.create<scf::ForOp>( // no iter_args; no-op body builder, the terminator is added by hand below
+ loc, lbValue, ubValue, stepValue, ValueRange(),
+ [](OpBuilder &, Location, Value, ValueRange) {});
+ ivs.push_back(loop.getInductionVar());
+ rewriter.setInsertionPointToStart(loop.getBody());
+ rewriter.create<scf::YieldOp>(loc); // terminator of the new loop body
+ rewriter.setInsertionPointToStart(loop.getBody()); // subsequent ops go before the yield
+ }
+ rewriter.eraseOp(target.getBody()->getTerminator()); // drop the forall terminator before inlining its body
+ rewriter.inlineBlockBefore(target.getBody(), &*rewriter.getInsertionPoint(), // body lands at the current insertion point
+ ivs); // values passed in place of the forall block arguments
+ rewriter.eraseOp(target);
+
+ for (auto &&[i, iv] : llvm::enumerate(ivs)) { // associate the i-th result handle with the i-th generated loop
+ results.set(cast<OpResult>(getTransformed()[i]),
+ {iv.getParentBlock()->getParentOp()}); // the op owning the block that defines iv
+ }
+ return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// LoopOutlineOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/transform-op-forall-to-for.mlir b/mlir/test/Dialect/SCF/transform-op-forall-to-for.mlir
new file mode 100644
index 000000000000000..4b46c68d06d3514
--- /dev/null
+++ b/mlir/test/Dialect/SCF/transform-op-forall-to-for.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file --verify-diagnostics | FileCheck %s
+
+func.func private @callee(%i: index, %j: index) // external callee used so the loop body has a use of both induction variables
+
+// CHECK-LABEL: @two_iters
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
+func.func @two_iters(%ub1: index, %ub2: index) { // positive case: a two-dimensional forall becomes two nested scf.for loops
+ scf.forall (%i, %j) in (%ub1, %ub2) {
+ func.call @callee(%i, %j) : (index, index) -> ()
+ }
+ // CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
+ // CHECK: scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
+ // CHECK: func.call @callee(%[[IV1]], %[[IV2]])
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.loop.forall_to_for %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) // two results, one per generated loop
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+
+func.func @repeated(%ub1: index, %ub2: index) { // negative case: the match below yields two scf.forall payload ops
+ scf.forall (%i, %j) in (%ub1, %ub2) {
+ func.call @callee(%i, %j) : (index, index) -> ()
+ }
+ scf.forall (%i, %j) in (%ub1, %ub2) {
+ func.call @callee(%i, %j) : (index, index) -> ()
+ }
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{expected a single payload op}}
+ transform.loop.forall_to_for %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) // handle maps to two ops, so the transform is rejected
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+
+func.func @repeated(%ub1: index, %ub2: index) { // negative case: two induction variables but only one result handle requested
+ // expected-note @below {{payload op}}
+ scf.forall (%i, %j) in (%ub1, %ub2) {
+ func.call @callee(%i, %j) : (index, index) -> ()
+ }
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{op expects as many results (1) as payload has induction variables (2)}}
+ transform.loop.forall_to_for %0 : (!transform.any_op) -> !transform.any_op // single result for a two-dimensional forall
+}
+
+// -----
+
+// expected-note @below {{payload op}}
+func.func private @callee(%i: index, %j: index) // negative case: this func.func is the payload, not an scf.forall
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{expected the payload to be scf.forall}}
+ transform.loop.forall_to_for %0 : (!transform.any_op) -> !transform.any_op // handle points at a func.func, so the cast to scf.forall fails
+}
More information about the Mlir-commits
mailing list