[Mlir-commits] [mlir] 9aaf007 - [SCF][Transform] Add transform.loop.fuse_sibling
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Aug 19 03:03:07 PDT 2023
Author: Groverkss
Date: 2023-08-19T15:24:23+05:30
New Revision: 9aaf007a982aa9cd0c639c09980962fd772c8eb8
URL: https://github.com/llvm/llvm-project/commit/9aaf007a982aa9cd0c639c09980962fd772c8eb8
DIFF: https://github.com/llvm/llvm-project/commit/9aaf007a982aa9cd0c639c09980962fd772c8eb8.diff
LOG: [SCF][Transform] Add transform.loop.fuse_sibling
This patch adds a new transform operation `transform.loop.fuse_sibling`,
which given two loops, fuses them, assuming that they are independent.
The transform operation itself performs very basic checks to ensure
IR legality, and leaves the responsibility of ensuring independence on the user.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D157069
Added:
mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
Modified:
mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
mlir/include/mlir/Dialect/SCF/Utils/Utils.h
mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
mlir/lib/Dialect/SCF/Utils/Utils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 55a378d96fe677..3efc047ff0786d 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -310,4 +310,39 @@ def TakeAssumedBranchOp : Op<Transform_Dialect, "scf.take_assumed_branch", [
}];
}
+def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
+  [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+   DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Fuse a loop into another loop, assuming the fusion is legal.";
+
+  let description = [{
+    Fuses the `target` loop into the `source` loop assuming they are
+    independent of each other. It is the responsibility of the user to ensure
+    that the given two loops are independent of each other: this operation
+    only performs basic structural checks (the loops must be distinct
+    siblings in the same block, and fusing them must not violate dominance)
+    and does not verify that the loops are free of conflicting side effects.
+
+    Currently, the only fusion supported is when both `target` and `source`
+    are `scf.forall` operations. For `scf.forall` fusion, the lower bounds,
+    upper bounds, steps and mapping must match, otherwise a silenceable
+    failure is produced. The results of the fused loop are the results of
+    `target` followed by the results of `source`.
+
+    The input handles `target` and `source` must map to exactly one operation,
+    a definite failure is produced otherwise.
+
+    #### Return modes
+
+    This operation consumes the `target` and `source` handles and produces the
+    `fused_loop` handle, which points to the fused loop.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       TransformHandleTypeInterface:$source);
+  let results = (outs TransformHandleTypeInterface:$fused_loop);
+  let assemblyFormat = "$target `into` $source attr-dict "
+                       " `:` functional-type(operands, results)";
+
+  // NOTE(review): no C++ definition of this builder appears in this patch;
+  // confirm one exists elsewhere or drop the declaration.
+  let builders = [
+    OpBuilder<(ins "Value":$loop, "Value":$fused_loop)>
+  ];
+}
+
#endif // SCF_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 2e299fd357f282..bde30c9c3528db 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -185,6 +185,17 @@ Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes);
void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
scf::ForOp root);
+/// Given two scf.forall loops, `target` and `source`, fuses `target` into
+/// `source`. Assumes that the given loops are siblings and are independent of
+/// each other.
+///
+/// The fused loop is created immediately after `source` and reuses the
+/// bounds, steps and mapping of `source`. Its shared outputs (and results)
+/// are those of `target` followed by those of `source`. All uses of the
+/// results of both loops are replaced by the corresponding results of the
+/// fused loop, and both original loops are erased. The insertion point of
+/// `rewriter` is modified.
+///
+/// This function does not perform any legality checks and simply fuses the
+/// loops. The caller is responsible for ensuring that the loops are legal to
+/// fuse, including that `target` has the same iteration space as `source`.
+scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
+                                                scf::ForallOp source,
+                                                RewriterBase &rewriter);
+
} // namespace mlir
#endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 0ca32723e4a38b..5b8dd2c68b84e5 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Dominance.h"
using namespace mlir;
using namespace mlir::affine;
@@ -318,6 +319,146 @@ void transform::TakeAssumedBranchOp::getEffects(
modifiesPayload(effects);
}
+//===----------------------------------------------------------------------===//
+// LoopFuseSibling
+//===----------------------------------------------------------------------===//
+
+/// Check if `target` and `source` are siblings, in the context that `target`
+/// is being fused into `source`.
+///
+/// Both operations must be distinct and live in the same block. In addition,
+/// dominance must be preserved once the fused loop is created right after
+/// `source`:
+///   * if `target` precedes `source`, every user of `target`'s results must
+///     be properly dominated by `source`;
+///   * if `target` follows `source`, every value `target` reads (as an
+///     operand, or from inside its regions) must dominate `source`.
+///
+/// Returns success, or a silenceable failure describing the first violation.
+static DiagnosedSilenceableFailure isOpSibling(Operation *target,
+                                               Operation *source) {
+  // Check if both operations are same.
+  if (target == source)
+    return emitSilenceableFailure(source)
+           << "target and source need to be different loops";
+
+  // Check if both operations are in the same block.
+  if (target->getBlock() != source->getBlock())
+    return emitSilenceableFailure(source)
+           << "target and source are not in the same block";
+
+  // Check if fusion will violate dominance.
+  DominanceInfo domInfo(source);
+  if (target->isBeforeInBlock(source)) {
+    // Since `target` is before `source`, all users of results of `target`
+    // need to be dominated by `source`.
+    for (Operation *user : target->getUsers()) {
+      if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
+        return emitSilenceableFailure(target)
+               << "user of results of target should be properly dominated by "
+                  "source";
+      }
+    }
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  // Since `target` is after `source`, all values used by `target` need to
+  // dominate `source`. A block argument always does: it belongs to the block
+  // holding both ops or to an enclosing block. Any other value dominates
+  // `source` iff its defining operation properly dominates `source`.
+  auto dominatesSource = [&](Value value) {
+    Operation *defOp = value.getDefiningOp();
+    return !defOp ||
+           domInfo.properlyDominates(defOp, source, /*enclosingOpOk=*/false);
+  };
+
+  // Check if operands of `target` are dominated by `source`.
+  for (Value operand : target->getOperands()) {
+    if (!dominatesSource(operand))
+      return emitSilenceableFailure(target)
+             << "operands of target should be properly dominated by source";
+  }
+
+  // Check if values defined above and used inside the regions of `target`
+  // dominate `source`. It is the *defining* op of the used value that must
+  // dominate `source`; the using op lives inside `target` and therefore never
+  // dominates `source`. Only the first offending use is reported.
+  OpOperand *failedUse = nullptr;
+  visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
+    if (!failedUse && !dominatesSource(operand->get()))
+      failedUse = operand;
+  });
+
+  if (failedUse)
+    return emitSilenceableFailure(failedUse->getOwner())
+           << "values used inside regions of target should be properly "
+              "dominated by source";
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Returns true when both `target` and `source` are `scf.forall` loops that
+/// agree on their lower bounds, upper bounds, steps and mapping (each compared
+/// element-wise with `==`).
+///
+/// Side effects are not inspected: whether `target` and `source` are
+/// independent of each other remains the caller's responsibility.
+static bool isForallWithIdenticalConfiguration(Operation *target,
+                                               Operation *source) {
+  auto targetForall = dyn_cast<scf::ForallOp>(target);
+  auto sourceForall = dyn_cast<scf::ForallOp>(source);
+  if (!targetForall || !sourceForall)
+    return false;
+
+  // Bail out at the first mismatching property.
+  if (targetForall.getMixedLowerBound() != sourceForall.getMixedLowerBound())
+    return false;
+  if (targetForall.getMixedUpperBound() != sourceForall.getMixedUpperBound())
+    return false;
+  if (targetForall.getMixedStep() != sourceForall.getMixedStep())
+    return false;
+  return targetForall.getMapping() == sourceForall.getMapping();
+}
+
+/// Fuses `target` into `source`, assuming they are siblings and independent.
+/// Returns the fused loop, or nullptr when the pair of operation kinds is not
+/// supported.
+/// TODO: Add fusion for more operations. Currently, we handle only scf.forall.
+static Operation *fuseSiblings(Operation *target, Operation *source,
+                               RewriterBase &rewriter) {
+  if (auto targetForall = dyn_cast<scf::ForallOp>(target)) {
+    if (auto sourceForall = dyn_cast<scf::ForallOp>(source))
+      return fuseIndependentSiblingForallLoops(targetForall, sourceForall,
+                                               rewriter);
+  }
+  return nullptr;
+}
+
+/// Applies `transform.loop.fuse_sibling`. The process has four steps:
+///   1. Validate that each handle maps to exactly one payload op; otherwise
+///      produce a definite failure.
+///   2. Check that the two payload ops are fusable siblings.
+///   3. Check that they are `scf.forall` loops with matching configuration.
+///      A failure in step 2 or 3 is silenceable.
+///   4. Perform the fusion and bind the result handle to the fused loop.
+DiagnosedSilenceableFailure
+transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
+                                  transform::TransformResults &results,
+                                  transform::TransformState &state) {
+  auto targetPayload = state.getPayloadOps(getTarget());
+  auto sourcePayload = state.getPayloadOps(getSource());
+
+  // Each handle must be associated with exactly one payload operation.
+  const bool singleTarget = llvm::hasSingleElement(targetPayload);
+  if (!singleTarget || !llvm::hasSingleElement(sourcePayload)) {
+    return emitDefiniteFailure()
+           << "requires exactly one target handle (got "
+           << llvm::range_size(targetPayload) << ") and exactly one "
+           << "source handle (got " << llvm::range_size(sourcePayload) << ")";
+  }
+
+  Operation *targetOp = *targetPayload.begin();
+  Operation *sourceOp = *sourcePayload.begin();
+
+  // Structural preconditions: distinct ops in one block, dominance preserved.
+  DiagnosedSilenceableFailure siblingCheck = isOpSibling(targetOp, sourceOp);
+  if (!siblingCheck.succeeded())
+    return siblingCheck;
+
+  // Only scf.forall loops sharing bounds, steps and mapping can be fused.
+  if (!isForallWithIdenticalConfiguration(targetOp, sourceOp))
+    return emitSilenceableFailure(targetOp->getLoc())
+           << "operations cannot be fused";
+
+  Operation *fused = fuseSiblings(targetOp, sourceOp, rewriter);
+  assert(fused && "failed to fuse operations");
+
+  results.set(cast<OpResult>(getFusedLoop()), {fused});
+  return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index b3e8ef7ef64365..9ac751f1915ab1 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -970,3 +970,68 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
return tileLoops;
}
+
+/// Builds one scf.forall right after `source` that runs the bodies of `target`
+/// and `source` (in that order) over `source`'s iteration space.
+///
+/// The shared outputs, in_parallel yielding ops and results of the new loop
+/// are those of `target` followed by those of `source`. Users of the original
+/// loops are redirected to the fused results, and both originals are erased.
+scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
+                                                      scf::ForallOp source,
+                                                      RewriterBase &rewriter) {
+  const unsigned targetResultCount = target.getNumResults();
+  const unsigned sourceResultCount = source.getNumResults();
+
+  // Shared outputs of the fused loop: `target`'s first, then `source`'s.
+  SmallVector<Value> combinedInits;
+  combinedInits.reserve(targetResultCount + sourceResultCount);
+  for (OperandRange inits : {target.getOutputs(), source.getOutputs()})
+    combinedInits.append(inits.begin(), inits.end());
+
+  // Materialize the fused loop right after `source`, reusing its iteration
+  // space and mapping.
+  rewriter.setInsertionPointAfter(source);
+  auto fusedLoop = rewriter.create<scf::ForallOp>(
+      source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
+      source.getMixedStep(), combinedInits, source.getMapping());
+
+  // Both original bodies are rewritten onto the fused induction variables.
+  // Each loop's shared_outs block arguments map onto its own slice of the
+  // fused ones.
+  IRMapping mapper;
+  mapper.map(target.getInductionVars(), fusedLoop.getInductionVars());
+  mapper.map(source.getInductionVars(), fusedLoop.getInductionVars());
+  auto fusedOutArgs = fusedLoop.getOutputBlockArguments();
+  mapper.map(target.getOutputBlockArguments(),
+             fusedOutArgs.take_front(targetResultCount));
+  mapper.map(source.getOutputBlockArguments(),
+             fusedOutArgs.drop_front(targetResultCount));
+
+  // Clones every op of `ops` at the rewriter's current insertion point,
+  // remapping operands through `mapper`.
+  auto cloneInto = [&](auto &&ops) {
+    for (Operation &op : ops)
+      rewriter.clone(op, mapper);
+  };
+
+  // Loop bodies without their terminators: `target` first, then `source`.
+  rewriter.setInsertionPointToStart(fusedLoop.getBody());
+  cloneInto(target.getLoopBody().begin()->without_terminator());
+  cloneInto(source.getLoopBody().begin()->without_terminator());
+
+  // Merge the yielding ops of both in_parallel terminators, in the same order.
+  scf::InParallelOp fusedTerminator = fusedLoop.getTerminator();
+  rewriter.setInsertionPointToStart(fusedTerminator.getBody());
+  cloneInto(target.getTerminator().getYieldingOps());
+  cloneInto(source.getTerminator().getYieldingOps());
+
+  // Redirect users of both loops to the matching fused results, then drop the
+  // originals.
+  ResultRange fusedResults = fusedLoop.getResults();
+  rewriter.replaceAllUsesWith(target.getResults(),
+                              fusedResults.take_front(targetResultCount));
+  rewriter.replaceAllUsesWith(source.getResults(),
+                              fusedResults.drop_front(targetResultCount));
+  rewriter.eraseOp(target);
+  rewriter.eraseOp(source);
+
+  return fusedLoop;
+}
diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
new file mode 100644
index 00000000000000..54e1b37ff3f6af
--- /dev/null
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -0,0 +1,113 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter --cse --canonicalize -split-input-file -verify-diagnostics | FileCheck %s
+
+// Positive case: two independent matmuls are tiled into two scf.forall loops
+// with identical bounds. %loop1 (the earlier loop) is then fused into %loop2.
+// The result is a single scf.forall whose shared_outs, bodies and in_parallel
+// ops are those of %loop1 followed by those of %loop2.
+func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+ %zero = arith.constant 0.0 : f32
+ %out_alloc = tensor.empty() : tensor<128x128xf32>
+ %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+ // CHECK: scf.forall ([[I:%.*]]) in (4) shared_outs([[S1:%.*]] = [[IN1:%.*]], [[S2:%.*]] = [[IN2:%.*]]) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+ // CHECK: [[T:%.*]] = affine.apply
+ // CHECK: tensor.extract_slice [[S1]][[[T]], 0] [32, 128] [1, 1]
+ // CHECK: [[OUT1:%.*]] = linalg.matmul
+ // CHECK: tensor.extract_slice [[S2]][[[T]], 0] [32, 128] [1, 1]
+ // CHECK: [[OUT2:%.*]] = linalg.matmul
+ // CHECK: scf.forall.in_parallel {
+ // CHECK: tensor.parallel_insert_slice [[OUT1]] into [[S1]][[[T]], 0] [32, 128] [1, 1]
+ // CHECK: tensor.parallel_insert_slice [[OUT2]] into [[S2]][[[T]], 0] [32, 128] [1, 1]
+ // CHECK: }
+ // CHECK: }
+ %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+ %out2 = linalg.matmul ins(%A, %B2 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+ func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%variant_op : !transform.any_op):
+ %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
+
+ %mm1, %mm2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+ %loop1, %tiled_mm1 = transform.structured.tile_to_forall_op %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %loop2, %tiled_mm2 = transform.structured.tile_to_forall_op %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+ // The earlier loop is the fusion target; the later loop is the source.
+ %fused_loop = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+// Negative case: %out2 consumes %out1, so a user of %loop1's results lives
+// inside %loop2. Fusing the earlier %loop1 into %loop2 is rejected because
+// every user of the target's results must be properly dominated by the source
+// loop.
+func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+ %zero = arith.constant 0.0 : f32
+ %out_alloc = tensor.empty() : tensor<128x128xf32>
+ %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+ // expected-error @below {{user of results of target should be properly dominated by source}}
+ %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+ %out2 = linalg.matmul ins(%A, %out1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+ func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%variant_op : !transform.any_op):
+ %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
+
+ %mm1, %mm2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+ %loop1, %tiled_mm1 = transform.structured.tile_to_forall_op %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %loop2, %tiled_mm2 = transform.structured.tile_to_forall_op %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+ %fused_loop = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+// Negative case: the later %loop2 is fused into the earlier %loop1. The
+// matmul tiled into %loop2 reads %out1, the result of %loop1, so a value
+// used inside the target's region does not dominate the source loop.
+func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+ %zero = arith.constant 0.0 : f32
+ %out_alloc = tensor.empty() : tensor<128x128xf32>
+ %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+ %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+ // expected-error @below {{values used inside regions of target should be properly dominated by source}}
+ %out2 = linalg.matmul ins(%A, %out1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+ func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%variant_op : !transform.any_op):
+ %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
+
+ %mm1, %mm2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+ %loop1, %tiled_mm1 = transform.structured.tile_to_forall_op %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %loop2, %tiled_mm2 = transform.structured.tile_to_forall_op %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+ %fused_loop = transform.loop.fuse_sibling %loop2 into %loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+// Negative case: the later %loop2 is fused into the earlier %loop1. %out1,
+// the result of %loop1, is the `outs` of the matmul tiled into %loop2, which
+// makes it an operand of the target that does not dominate the source loop.
+func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+ %zero = arith.constant 0.0 : f32
+ %out_alloc = tensor.empty() : tensor<128x128xf32>
+ %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+ %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+ // expected-error @below {{operands of target should be properly dominated by source}}
+ %out2 = linalg.matmul ins(%A, %B2 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out1 : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+ func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%variant_op : !transform.any_op):
+ %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
+
+ %mm1, %mm2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+ %loop1, %tiled_mm1 = transform.structured.tile_to_forall_op %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %loop2, %tiled_mm2 = transform.structured.tile_to_forall_op %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+ %fused_loop = transform.loop.fuse_sibling %loop2 into %loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+}
More information about the Mlir-commits
mailing list