[Mlir-commits] [mlir] 4fbb5f9 - [mlir] introduce transform.loop.forall_to_for (#65474)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 20 09:53:12 PDT 2023


Author: Oleksandr "Alex" Zinenko
Date: 2023-09-20T18:53:08+02:00
New Revision: 4fbb5f93506317f063753c2b4aecebc86d63264e

URL: https://github.com/llvm/llvm-project/commit/4fbb5f93506317f063753c2b4aecebc86d63264e
DIFF: https://github.com/llvm/llvm-project/commit/4fbb5f93506317f063753c2b4aecebc86d63264e.diff

LOG: [mlir] introduce transform.loop.forall_to_for (#65474)

Add a straightforward sequentialization transform from `scf.forall` to a
nest of `scf.for` in the absence of results and expose it as a transform op.
This is helpful in combination with other transform ops, particularly
fusion, that work best on parallel-by-construction `scf.forall` but
later need to target sequential `for` loops.

Added: 
    mlir/test/Dialect/SCF/transform-op-forall-to-for.mlir

Modified: 
    mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h
    mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
    mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h
index 26cc9b16cd9ca74..d14d63e56dc764f 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h
@@ -20,6 +20,7 @@ namespace func {
 class FuncOp;
 } // namespace func
 namespace scf {
+class ForallOp;
 class ForOp;
 class IfOp;
 } // namespace scf

diff  --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 6f48b005bfcbf6e..207a004c54ef5af 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -40,6 +40,34 @@ def ApplySCFStructuralConversionPatternsOp : Op<Transform_Dialect,
 
 def Transform_ScfForOp : Transform_ConcreteOpType<"scf.for">;
 
+def ForallToForOp : Op<Transform_Dialect, "loop.forall_to_for",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Converts scf.forall into a nest of scf.for operations";
+  let description = [{
+    Converts the `scf.forall` operation pointed to by the given handle into a
+    set of nested `scf.for` operations. Each new operation corresponds to one
+    induction variable of the original "multifor" loop: the outermost
+    `scf.for` corresponds to the first induction variable, the next one to the
+    second, and so on. The body of the `scf.forall` is moved into the
+    innermost `scf.for`.
+
+    The operand handle must be associated with exactly one payload operation.
+
+    Loops with shared outputs are currently not supported.
+
+    #### Return Modes
+
+    Consumes the operand handle. Produces a silenceable failure if the operand
+    is not associated with a single `scf.forall` payload operation, or if that
+    operation has shared outputs.
+    Returns one handle per induction variable of the given `forall` op, in the
+    order of the induction variables; each handle is associated with the
+    `scf.for` loop generated for the corresponding induction variable.
+    Produces a silenceable failure if a different number of resulting handles
+    is requested.
+  }];
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+}
+
 def GetParentForOp : Op<Transform_Dialect, "loop.get_parent_for",
     [NavigationTransformOpTrait, MemoryEffectsOpInterface,
      DeclareOpInterfaceMethods<TransformOpInterface>]> {

diff  --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 88ddd22eea46b35..d7e8c38478ced1a 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -9,6 +9,8 @@
 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
@@ -17,8 +19,11 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/IR/OpDefinition.h"
 
 using namespace mlir;
 using namespace mlir::affine;
@@ -47,6 +52,7 @@ void transform::ApplySCFStructuralConversionPatternsOp::
 //===----------------------------------------------------------------------===//
 // GetParentForOp
 //===----------------------------------------------------------------------===//
+
 DiagnosedSilenceableFailure
 transform::GetParentForOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &results,
@@ -76,6 +82,72 @@ transform::GetParentForOp::apply(transform::TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// ForallToForOp
+//===----------------------------------------------------------------------===//
+
+// Sequentializes a single `scf.forall` payload op into a nest of `scf.for`
+// loops, one per induction variable, with the outermost loop corresponding to
+// the first induction variable. The `forall` body (without its terminator) is
+// moved into the innermost loop and the `forall` op is erased. The i-th result
+// handle is associated with the loop created for the i-th induction variable.
+//
+// Produces a silenceable failure, before any IR is modified, if the handle is
+// not associated with exactly one `scf.forall`, if that op has shared outputs,
+// or if the number of result handles differs from the number of induction
+// variables.
+DiagnosedSilenceableFailure
+transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
+                                transform::TransformResults &results,
+                                transform::TransformState &state) {
+  auto payload = state.getPayloadOps(getTarget());
+  if (!llvm::hasSingleElement(payload))
+    return emitSilenceableError() << "expected a single payload op";
+
+  auto target = dyn_cast<scf::ForallOp>(*payload.begin());
+  if (!target) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "expected the payload to be scf.forall";
+    diag.attachNote((*payload.begin())->getLoc()) << "payload op";
+    return diag;
+  }
+
+  // Shared outputs would have to be threaded through the loop nest as
+  // iteration arguments, which this transform does not support.
+  if (!target.getOutputs().empty()) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError()
+        << "unsupported shared outputs (didn't bufferize?)";
+    diag.attachNote(target.getLoc()) << "payload op";
+    return diag;
+  }
+
+  SmallVector<OpFoldResult> lbs = target.getMixedLowerBound();
+  SmallVector<OpFoldResult> ubs = target.getMixedUpperBound();
+  SmallVector<OpFoldResult> steps = target.getMixedStep();
+
+  if (getNumResults() != lbs.size()) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError()
+        << "op expects as many results (" << getNumResults()
+        << ") as payload has induction variables (" << lbs.size() << ")";
+    diag.attachNote(target.getLoc()) << "payload op";
+    return diag;
+  }
+
+  // All checks passed; only now start modifying the IR. Materialize every
+  // bound in front of the `forall` so that constant bounds are created once,
+  // outside the loop nest, instead of inside the bodies of enclosing loops.
+  Location loc = target.getLoc();
+  rewriter.setInsertionPoint(target);
+  SmallVector<Value> lbValues, ubValues, stepValues;
+  for (auto &&[lb, ub, step] : llvm::zip(lbs, ubs, steps)) {
+    lbValues.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, lb));
+    ubValues.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, ub));
+    stepValues.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, step));
+  }
+
+  // Build the nest outermost-first. Each loop is created with an empty body;
+  // its terminator is added explicitly and the next loop (or, for the
+  // innermost one, the `forall` body) is inserted in front of it.
+  SmallVector<Value> ivs;
+  for (auto &&[lb, ub, step] : llvm::zip(lbValues, ubValues, stepValues)) {
+    auto loop = rewriter.create<scf::ForOp>(
+        loc, lb, ub, step, ValueRange(),
+        [](OpBuilder &, Location, Value, ValueRange) {});
+    ivs.push_back(loop.getInductionVar());
+    rewriter.setInsertionPointToStart(loop.getBody());
+    rewriter.create<scf::YieldOp>(loc);
+    rewriter.setInsertionPointToStart(loop.getBody());
+  }
+
+  // Move the `forall` body into the innermost loop, replacing its block
+  // arguments (only induction variables, as there are no shared outputs) with
+  // the induction variables of the new loops.
+  rewriter.eraseOp(target.getBody()->getTerminator());
+  rewriter.inlineBlockBefore(target.getBody(), &*rewriter.getInsertionPoint(),
+                             ivs);
+  rewriter.eraseOp(target);
+
+  for (auto &&[i, iv] : llvm::enumerate(ivs)) {
+    results.set(cast<OpResult>(getTransformed()[i]),
+                {iv.getParentBlock()->getParentOp()});
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // LoopOutlineOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SCF/transform-op-forall-to-for.mlir b/mlir/test/Dialect/SCF/transform-op-forall-to-for.mlir
new file mode 100644
index 000000000000000..4b46c68d06d3514
--- /dev/null
+++ b/mlir/test/Dialect/SCF/transform-op-forall-to-for.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file --verify-diagnostics | FileCheck %s
+
+// A two-dimensional scf.forall is rewritten into two nested scf.for loops:
+// the outer loop iterates over the first induction variable, the inner one
+// over the second, and the original body ends up in the inner loop using the
+// new induction variables.
+func.func private @callee(%i: index, %j: index)
+
+// CHECK-LABEL: @two_iters
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
+func.func @two_iters(%ub1: index, %ub2: index) {
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
+  // CHECK:   scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
+  // CHECK:     func.call @callee(%[[IV1]], %[[IV2]])
+  return
+}
+
+// One result handle is requested per induction variable of the payload.
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  transform.loop.forall_to_for %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+}
+
+// -----
+
+// The transform requires its operand handle to be associated with exactly one
+// payload op; here the matcher associates it with two scf.forall ops.
+func.func private @callee(%i: index, %j: index)
+
+func.func @repeated(%ub1: index, %ub2: index) {
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  // expected-error @below {{expected a single payload op}}
+  transform.loop.forall_to_for %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+}
+
+// -----
+
+// The number of requested result handles (1) must match the number of
+// induction variables of the payload scf.forall (2).
+func.func private @callee(%i: index, %j: index)
+
+func.func @wrong_num_results(%ub1: index, %ub2: index) {
+  // expected-note @below {{payload op}}
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  // expected-error @below {{op expects as many results (1) as payload has induction variables (2)}}
+  transform.loop.forall_to_for %0 : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+// The payload must be an scf.forall; here the handle points to a function
+// declaration instead, and the diagnostic note is attached to that function.
+// expected-note @below {{payload op}}
+func.func private @callee(%i: index, %j: index)
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  // expected-error @below {{expected the payload to be scf.forall}}
+  transform.loop.forall_to_for %0 : (!transform.any_op) -> !transform.any_op
+}


        


More information about the Mlir-commits mailing list