[Mlir-commits] [mlir] [SCF][Transform] Add support for scf.for in LoopFuseSibling op (PR #81495)
Rolf Morel
llvmlistbot at llvm.org
Thu Mar 28 03:37:18 PDT 2024
https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/81495
From f464d288c4106630946cc6a993cacf314dca8aca Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at huawei.com>
Date: Tue, 30 Jan 2024 00:14:57 +0800
Subject: [PATCH] [SCF][Transform] Add support for scf.for in LoopFuseSibling
op
Adds support for fusing two scf.for loops occurring in the same block.
Uses the rudimentary checks already in place for scf.forall (e.g. the target
loop's operands being dominated by the source loop).
- Fixes a bug in the dominance check whereby it checked that the ops using
values inside the target loop dominated the source loop, rather than the ops
that define these values.
- Renames the LoopFuseSibling op to LoopFuseSiblingOp.
- Updates the LoopFuseSiblingOp's description.
- Adds tests for using LoopFuseSiblingOp on scf.for loops, including one which
fails without the fix for the dominance check.
- Adds tests checking the different failure modes of the dominance checker.
- Adds a test for the case where an scf.yield terminator is automatically
generated when there are no loop-carried variables.
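
For orientation, here is a minimal sketch of the intended effect on two sibling scf.for loops; the function names, shapes and trivial bodies are illustrative only and not taken from the patch. Fusing the first loop (target) into the second (source) yields a single loop whose iter_args and results place the target's before the source's:

func.func @sketch_before(%A: tensor<8xf32>, %B: tensor<8xf32>,
                         %lb: index, %ub: index, %step: index)
    -> (tensor<8xf32>, tensor<8xf32>) {
  // Two independent sibling loops with matching bounds and step.
  %r0 = scf.for %i = %lb to %ub step %step iter_args(%a = %A) -> (tensor<8xf32>) {
    scf.yield %a : tensor<8xf32>
  }
  %r1 = scf.for %i = %lb to %ub step %step iter_args(%b = %B) -> (tensor<8xf32>) {
    scf.yield %b : tensor<8xf32>
  }
  return %r0, %r1 : tensor<8xf32>, tensor<8xf32>
}

func.func @sketch_after(%A: tensor<8xf32>, %B: tensor<8xf32>,
                        %lb: index, %ub: index, %step: index)
    -> (tensor<8xf32>, tensor<8xf32>) {
  // Fused loop: the target's iter_arg and result come first, then the source's.
  %r:2 = scf.for %i = %lb to %ub step %step iter_args(%a = %A, %b = %B)
      -> (tensor<8xf32>, tensor<8xf32>) {
    scf.yield %a, %b : tensor<8xf32>, tensor<8xf32>
  }
  return %r#0, %r#1 : tensor<8xf32>, tensor<8xf32>
}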
---
.../SCF/TransformOps/SCFTransformOps.td | 23 +-
mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 10 +
.../SCF/TransformOps/SCFTransformOps.cpp | 66 +++--
mlir/lib/Dialect/SCF/Utils/Utils.cpp | 99 ++++---
.../SCF/transform-loop-fuse-sibling.mlir | 251 ++++++++++++++++--
5 files changed, 357 insertions(+), 92 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 6f94cee5b01911..5eefe2664d0a1b 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -333,23 +333,24 @@ def TakeAssumedBranchOp : Op<Transform_Dialect, "scf.take_assumed_branch", [
}];
}
-def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
+def LoopFuseSiblingOp : Op<Transform_Dialect, "loop.fuse_sibling",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
let summary = "Fuse a loop into another loop, assuming the fusion is legal.";
let description = [{
Fuses the `target` loop into the `source` loop assuming they are
- independent of each other. It is the responsibility of the user to ensure
- that the given two loops are independent of each other, this operation will
- not performa any legality checks and will simply fuse the two given loops.
+ independent of each other. In the fused loop, the arguments, body and
+ results of `target` are placed _before_ those of `source`.
- Currently, the only fusion supported is when both `target` and `source`
- are `scf.forall` operations. For `scf.forall` fusion, the bounds and the
- mapping must match, otherwise a silencable failure is produced.
+ For fusion of two `scf.for` loops, the bounds and step size must match. For
+ fusion of two `scf.forall` loops, the bounds and the mapping must match.
+ Otherwise a silenceable failure is produced.
- The input handles `target` and `source` must map to exactly one operation,
- a definite failure is produced otherwise.
+ The `target` and `source` handles must each refer to exactly one operation,
+ otherwise a definite failure is produced. It is the responsibility of the
+ user to ensure that the `target` and `source` loops are independent of each
+ other -- this op will only perform rudimentary legality checks.
#### Return modes
@@ -362,10 +363,6 @@ def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
let results = (outs TransformHandleTypeInterface:$fused_loop);
let assemblyFormat = "$target `into` $source attr-dict "
" `:` functional-type(operands, results)";
-
- let builders = [
- OpBuilder<(ins "Value":$loop, "Value":$fused_loop)>
- ];
}
#endif // SCF_TRANSFORM_OPS
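
As a usage sketch of the op described above (mirroring the tests added later in this patch; handle and symbol names are illustrative): match the two sibling scf.for loops, split the handle so each handle maps to exactly one op, then fuse:

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %loops = transform.structured.match ops{["scf.for"]} in %root
      : (!transform.any_op) -> !transform.any_op
    %for:2 = transform.split_handle %loops
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    // Fuse the first loop into the second; a silenceable failure is produced
    // if the bounds or steps differ.
    %fused = transform.loop.fuse_sibling %for#0 into %for#1
      : (!transform.any_op, !transform.any_op) -> !transform.any_op
    transform.yield
  }
}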
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 9bdd6eb833876f..883d11bcc4df06 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -162,6 +162,16 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
scf::ForallOp source,
RewriterBase &rewriter);
+/// Given two scf.for loops, `target` and `source`, fuses `target` into
+/// `source`. Assumes that the given loops are siblings and are independent of
+/// each other.
+///
+/// This function does not perform any legality checks and simply fuses the
+/// loops. The caller is responsible for ensuring that the loops are legal to
+/// fuse.
+scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
+ RewriterBase &rewriter);
+
} // namespace mlir
#endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
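
To illustrate the caller's obligation stated in the new doc comment (a hedged sketch; names and shapes are hypothetical, the pattern matches the negative test added below): a pair like the following must not be handed to this utility, because the second loop consumes the first loop's result and the loops are therefore not independent:

func.func @do_not_fuse_dependent_pair(%A: tensor<8xf32>,
                                      %lb: index, %ub: index, %step: index) -> tensor<8xf32> {
  %r0 = scf.for %i = %lb to %ub step %step iter_args(%a = %A) -> (tensor<8xf32>) {
    scf.yield %a : tensor<8xf32>
  }
  // The init_arg of %r1 uses %r0, so these sibling loops are not independent.
  %r1 = scf.for %i = %lb to %ub step %step iter_args(%b = %r0) -> (tensor<8xf32>) {
    scf.yield %b : tensor<8xf32>
  }
  return %r1 : tensor<8xf32>
}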
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 4d8d93f7aac720..c0918414820803 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -384,7 +384,7 @@ void transform::TakeAssumedBranchOp::getEffects(
}
//===----------------------------------------------------------------------===//
-// LoopFuseSibling
+// LoopFuseSiblingOp
//===----------------------------------------------------------------------===//
/// Check if `target` and `source` are siblings, in the context that `target`
@@ -408,7 +408,7 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
// Check if fusion will violate dominance.
DominanceInfo domInfo(source);
if (target->isBeforeInBlock(source)) {
- // Since, `target` is before `source`, all users of results of `target`
+ // Since `target` is before `source`, all users of results of `target`
// need to be dominated by `source`.
for (Operation *user : target->getUsers()) {
if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
@@ -424,9 +424,8 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
// Check if operands of `target` are dominated by `source`.
for (Value operand : target->getOperands()) {
Operation *operandOp = operand.getDefiningOp();
- // If operand does not have a defining operation, it is a block arguement,
- // which will always dominate `source`, since `target` and `source` are in
- // the same block and the operand dominated `source` before.
+ // Operands without defining operations are block arguments. When `target`
+ // and `source` occur in the same block, these operands dominate `source`.
if (!operandOp)
continue;
@@ -441,8 +440,11 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
bool failed = false;
OpOperand *failedValue = nullptr;
visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
- if (!domInfo.properlyDominates(operand->getOwner(), source,
- /*enclosingOpOk=*/false)) {
+ Operation *operandOp = operand->get().getDefiningOp();
+ if (operandOp && !domInfo.properlyDominates(operandOp, source,
+ /*enclosingOpOk=*/false)) {
+ // `operand` is not an argument of an enclosing block and the defining
+ // op of `operand` is outside `target` but does not dominate `source`.
failed = true;
failedValue = operand;
}
@@ -457,12 +459,11 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
return DiagnosedSilenceableFailure::success();
}
-/// Check if `target` can be fused into `source`.
+/// Check if `target` scf.forall can be fused into `source` scf.forall.
///
-/// This is a simple check that just checks if both loops have same
-/// bounds, steps and mapping. This check does not ensure that the side effects
-/// of `target` are independent of `source` or vice-versa. It is the
-/// responsibility of the caller to ensure that.
+/// This simply checks if both loops have the same bounds, steps and mapping.
+/// No attempt is made at checking that the side effects of `target` and
+/// `source` are independent of each other.
static bool isForallWithIdenticalConfiguration(Operation *target,
Operation *source) {
auto targetOp = dyn_cast<scf::ForallOp>(target);
@@ -476,21 +477,27 @@ static bool isForallWithIdenticalConfiguration(Operation *target,
targetOp.getMapping() == sourceOp.getMapping();
}
-/// Fuse `target` into `source` assuming they are siblings and indepndent.
-/// TODO: Add fusion for more operations. Currently, we handle only scf.forall.
-static Operation *fuseSiblings(Operation *target, Operation *source,
- RewriterBase &rewriter) {
- auto targetOp = dyn_cast<scf::ForallOp>(target);
- auto sourceOp = dyn_cast<scf::ForallOp>(source);
+/// Check if `target` scf.for can be fused into `source` scf.for.
+///
+/// This simply checks if both loops have the same bounds and steps. No attempt
+/// is made at checking that the side effects of `target` and `source` are
+/// independent of each other.
+static bool isForWithIdenticalConfiguration(Operation *target,
+ Operation *source) {
+ auto targetOp = dyn_cast<scf::ForOp>(target);
+ auto sourceOp = dyn_cast<scf::ForOp>(source);
if (!targetOp || !sourceOp)
- return nullptr;
- return fuseIndependentSiblingForallLoops(targetOp, sourceOp, rewriter);
+ return false;
+
+ return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
+ targetOp.getUpperBound() == sourceOp.getUpperBound() &&
+ targetOp.getStep() == sourceOp.getStep();
}
DiagnosedSilenceableFailure
-transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
- transform::TransformResults &results,
- transform::TransformState &state) {
+transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
auto targetOps = state.getPayloadOps(getTarget());
auto sourceOps = state.getPayloadOps(getSource());
@@ -510,13 +517,18 @@ transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
if (!diag.succeeded())
return diag;
- // Check if the target can be fused into source.
- if (!isForallWithIdenticalConfiguration(target, source)) {
+ Operation *fusedLoop;
+ /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
+ if (isForWithIdenticalConfiguration(target, source)) {
+ fusedLoop = fuseIndependentSiblingForLoops(
+ cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
+ } else if (isForallWithIdenticalConfiguration(target, source)) {
+ fusedLoop = fuseIndependentSiblingForallLoops(
+ cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
+ } else
return emitSilenceableFailure(target->getLoc())
<< "operations cannot be fused";
- }
- Operation *fusedLoop = fuseSiblings(target, source, rewriter);
assert(fusedLoop && "failed to fuse operations");
results.set(cast<OpResult>(getFusedLoop()), {fusedLoop});
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 502d7e197a6f6b..914aeb4fa79fda 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -910,61 +910,98 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
unsigned numTargetOuts = target.getNumResults();
unsigned numSourceOuts = source.getNumResults();
- OperandRange targetOuts = target.getOutputs();
- OperandRange sourceOuts = source.getOutputs();
-
// Create fused shared_outs.
SmallVector<Value> fusedOuts;
- fusedOuts.reserve(numTargetOuts + numSourceOuts);
- fusedOuts.append(targetOuts.begin(), targetOuts.end());
- fusedOuts.append(sourceOuts.begin(), sourceOuts.end());
+ llvm::append_range(fusedOuts, target.getOutputs());
+ llvm::append_range(fusedOuts, source.getOutputs());
- // Create a new scf::forall op after the source loop.
+ // Create a new scf.forall op after the source loop.
rewriter.setInsertionPointAfter(source);
scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
source.getMixedStep(), fusedOuts, source.getMapping());
// Map control operands.
- IRMapping fusedMapping;
- fusedMapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
- fusedMapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
+ IRMapping mapping;
+ mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
+ mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
// Map shared outs.
- fusedMapping.map(target.getRegionIterArgs(),
- fusedLoop.getRegionIterArgs().slice(0, numTargetOuts));
- fusedMapping.map(
- source.getRegionIterArgs(),
- fusedLoop.getRegionIterArgs().slice(numTargetOuts, numSourceOuts));
+ mapping.map(target.getRegionIterArgs(),
+ fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
+ mapping.map(source.getRegionIterArgs(),
+ fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
// Append everything except the terminator into the fused operation.
rewriter.setInsertionPointToStart(fusedLoop.getBody());
for (Operation &op : target.getBody()->without_terminator())
- rewriter.clone(op, fusedMapping);
+ rewriter.clone(op, mapping);
for (Operation &op : source.getBody()->without_terminator())
- rewriter.clone(op, fusedMapping);
+ rewriter.clone(op, mapping);
// Fuse the old terminator in_parallel ops into the new one.
scf::InParallelOp targetTerm = target.getTerminator();
scf::InParallelOp sourceTerm = source.getTerminator();
scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
-
rewriter.setInsertionPointToStart(fusedTerm.getBody());
for (Operation &op : targetTerm.getYieldingOps())
- rewriter.clone(op, fusedMapping);
+ rewriter.clone(op, mapping);
for (Operation &op : sourceTerm.getYieldingOps())
- rewriter.clone(op, fusedMapping);
-
- // Replace all uses of the old loops with the fused loop.
- rewriter.replaceAllUsesWith(target.getResults(),
- fusedLoop.getResults().slice(0, numTargetOuts));
- rewriter.replaceAllUsesWith(
- source.getResults(),
- fusedLoop.getResults().slice(numTargetOuts, numSourceOuts));
-
- // Erase the old loops.
- rewriter.eraseOp(target);
- rewriter.eraseOp(source);
+ rewriter.clone(op, mapping);
+
+ // Replace old loops by substituting their uses by results of the fused loop.
+ rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
+ rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
+
+ return fusedLoop;
+}
+
+scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
+ scf::ForOp source,
+ RewriterBase &rewriter) {
+ unsigned numTargetOuts = target.getNumResults();
+ unsigned numSourceOuts = source.getNumResults();
+
+ // Create fused init_args, with target's init_args before source's init_args.
+ SmallVector<Value> fusedInitArgs;
+ llvm::append_range(fusedInitArgs, target.getInitArgs());
+ llvm::append_range(fusedInitArgs, source.getInitArgs());
+
+ // Create a new scf.for op after the source loop (when the fused init_args is
+ // empty, the builder itself creates an operand-less scf.yield terminator).
+ rewriter.setInsertionPointAfter(source);
+ scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
+ source.getLoc(), source.getLowerBound(), source.getUpperBound(),
+ source.getStep(), fusedInitArgs);
+
+ // Map original induction variables and operands to those of the fused loop.
+ IRMapping mapping;
+ mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
+ mapping.map(target.getRegionIterArgs(),
+ fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
+ mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
+ mapping.map(source.getRegionIterArgs(),
+ fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
+
+ // Clone target's body into the new (fused) for loop, then source's body.
+ rewriter.setInsertionPointToStart(fusedLoop.getBody());
+ for (Operation &op : target.getBody()->without_terminator())
+ rewriter.clone(op, mapping);
+ for (Operation &op : source.getBody()->without_terminator())
+ rewriter.clone(op, mapping);
+
+ // Build fused yield results by appropriately mapping original yield operands.
+ SmallVector<Value> yieldResults;
+ for (Value operand : target.getBody()->getTerminator()->getOperands())
+ yieldResults.push_back(mapping.lookupOrDefault(operand));
+ for (Value operand : source.getBody()->getTerminator()->getOperands())
+ yieldResults.push_back(mapping.lookupOrDefault(operand));
+ if (!yieldResults.empty())
+ rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
+
+ // Replace old loops by substituting their uses by results of the fused loop.
+ rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
+ rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
return fusedLoop;
}
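
As a rough sketch of the corner case handled above (hypothetical ops and shapes): when neither input loop carries iter_args, the fused scf.for has none either, and its body is terminated by the implicit, operand-less scf.yield created by the builder:

func.func @fused_no_iter_args(%lb: index, %ub: index, %step: index) {
  scf.for %i = %lb to %ub step %step {
    // Cloned target body first, then cloned source body; the implicit
    // scf.yield has no operands.
    %t = arith.addi %i, %step : index
    %s = arith.muli %i, %step : index
  }
  return
}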
diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
index faaa2db3aa57de..0f51b1cdbe0cf1 100644
--- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -1,14 +1,113 @@
// RUN: mlir-opt %s -transform-interpreter --cse --canonicalize -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s --check-prefix CHECK-NOCLEANUP
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+// CHECK: func.func @fuse_1st_for_into_2nd([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @fuse_1st_for_into_2nd(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
+ // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+ // CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index
+ // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index
+ // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c128 = arith.constant 128 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ // CHECK: [[R0:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C16]] iter_args([[IA:%.*]] = [[A]], [[IB:%.*]] = [[B]]) {{.*}}
+ %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %A) -> (tensor<128xf32>) {
+ // CHECK-DAG: [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]]
+ // CHECK-DAG: [[SLICE0:%.*]] = vector.transfer_read [[IA]][[[IV]]], [[ZERO]]
+ // CHECK: [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]]
+ // CHECK-NEXT: [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IA]][[[IV]]]
+ %2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ %5 = arith.addf %3, %2 : vector<16xf32>
+ %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+ scf.yield %6 : tensor<128xf32>
+ }
+ %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) {
+ // CHECK-DAG: [[SLICE1:%.*]] = vector.transfer_read [[IB]][[[IV]]], [[ZERO]]
+ // CHECK: [[OUT2:%.*]] = arith.addf [[SLICE1]], [[ASLICE]]
+ // CHECK-NEXT: [[WRT1:%.*]] = vector.transfer_write [[OUT2]], [[IB]][[[IV]]]
+ %dup2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ %dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
+ %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+ // CHECK: scf.yield [[WRT0]], [[WRT1]] : {{.*}}
+ scf.yield %dup6 : tensor<128xf32>
+ }
+ return %1, %dup1 : tensor<128xf32>, tensor<128xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %for:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %fused = transform.loop.fuse_sibling %for#0 into %for#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK: func.func @fuse_2nd_for_into_1st([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @fuse_2nd_for_into_1st(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
+ // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+ // CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index
+ // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index
+ // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c128 = arith.constant 128 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ // CHECK: [[R0:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C16]] iter_args([[IB:%.*]] = [[B]], [[IA:%.*]] = [[A]]) {{.*}}
+ %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %A) -> (tensor<128xf32>) {
+ // CHECK-DAG: [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]]
+ // CHECK-DAG: [[SLICE0:%.*]] = vector.transfer_read [[IB]][[[IV]]], [[ZERO]]
+ // CHECK: [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]]
+ // CHECK-NEXT: [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IB]][[[IV]]]
+ %2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ %5 = arith.addf %3, %2 : vector<16xf32>
+ %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+ scf.yield %6 : tensor<128xf32>
+ }
+ %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) {
+ // CHECK-DAG: [[SLICE1:%.*]] = vector.transfer_read [[IA]][[[IV]]], [[ZERO]]
+ // CHECK: [[OUT2:%.*]] = arith.addf [[SLICE1]], [[ASLICE]]
+ // CHECK-NEXT: [[WRT1:%.*]] = vector.transfer_write [[OUT2]], [[IA]][[[IV]]]
+ %dup2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ // NB: the dominance check used to fail on the following line; however, the
+ // defining op for the value of %arg3 occurs above the source loop and hence
+ // is safe, and %arg4 is a block argument of the scope of the loops and hence
+ // is safe.
+ %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ %dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
+ %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+ // CHECK: scf.yield [[WRT0]], [[WRT1]] : {{.*}}
+ scf.yield %dup6 : tensor<128xf32>
+ }
+ return %1, %dup1 : tensor<128xf32>, tensor<128xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %for:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %fused = transform.loop.fuse_sibling %for#1 into %for#0 : (!transform.any_op,!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK: func.func @matmul_fuse_1st_forall_into_2nd([[A1:%.*]]: {{.*}}, [[A2:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @matmul_fuse_1st_forall_into_2nd(%A1 : tensor<128x128xf32>, %A2 : tensor<128x128xf32>, %B : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
%zero = arith.constant 0.0 : f32
%out_alloc = tensor.empty() : tensor<128x128xf32>
%out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
// CHECK: scf.forall ([[I:%.*]]) in (4) shared_outs([[S1:%.*]] = [[IN1:%.*]], [[S2:%.*]] = [[IN2:%.*]]) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
// CHECK: [[T:%.*]] = affine.apply
+ // CHECK: tensor.extract_slice [[A2]][[[T]], 0] [32, 128] [1, 1]
// CHECK: tensor.extract_slice [[S1]][[[T]], 0] [32, 128] [1, 1]
// CHECK: [[OUT1:%.*]] = linalg.matmul
+ // CHECK: tensor.extract_slice [[A1]][[[T]], 0] [32, 128] [1, 1]
// CHECK: tensor.extract_slice [[S2]][[[T]], 0] [32, 128] [1, 1]
// CHECK: [[OUT2:%.*]] = linalg.matmul
// CHECK: scf.forall.in_parallel {
@@ -16,12 +115,11 @@ func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tenso
// CHECK: tensor.parallel_insert_slice [[OUT2]] into [[S2]][[[T]], 0] [32, 128] [1, 1]
// CHECK: }
// CHECK: }
- %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
- %out2 = linalg.matmul ins(%A, %B2 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+ %out1 = linalg.matmul ins(%A1, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+ %out2 = linalg.matmul ins(%A2, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
}
-
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) {
%matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
@@ -31,25 +129,37 @@ module attributes {transform.with_named_sequence} {
%tiled_mm1, %loop1 = transform.structured.tile_using_forall %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%tiled_mm2, %loop2 = transform.structured.tile_using_forall %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- %fused_loop = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+ %fused_loop = transform.loop.fuse_sibling %loop2 into %loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
transform.yield
}
}
// -----
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+// CHECK: func.func @matmul_fuse_2nd_forall_into_1st([[A1:%.*]]: {{.*}}, [[A2:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @matmul_fuse_2nd_forall_into_1st(%A1 : tensor<128x128xf32>, %A2 : tensor<128x128xf32>, %B : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
%zero = arith.constant 0.0 : f32
%out_alloc = tensor.empty() : tensor<128x128xf32>
%out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
- // expected-error @below {{user of results of target should be properly dominated by source}}
- %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
- %out2 = linalg.matmul ins(%A, %out1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+ // CHECK: scf.forall ([[I:%.*]]) in (4) shared_outs([[S1:%.*]] = [[IN1:%.*]], [[S2:%.*]] = [[IN2:%.*]]) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+ // CHECK: [[T:%.*]] = affine.apply
+ // CHECK: tensor.extract_slice [[A1]][[[T]], 0] [32, 128] [1, 1]
+ // CHECK: tensor.extract_slice [[S1]][[[T]], 0] [32, 128] [1, 1]
+ // CHECK: [[OUT1:%.*]] = linalg.matmul
+ // CHECK: tensor.extract_slice [[A2]][[[T]], 0] [32, 128] [1, 1]
+ // CHECK: tensor.extract_slice [[S2]][[[T]], 0] [32, 128] [1, 1]
+ // CHECK: [[OUT2:%.*]] = linalg.matmul
+ // CHECK: scf.forall.in_parallel {
+ // CHECK: tensor.parallel_insert_slice [[OUT1]] into [[S1]][[[T]], 0] [32, 128] [1, 1]
+ // CHECK: tensor.parallel_insert_slice [[OUT2]] into [[S2]][[[T]], 0] [32, 128] [1, 1]
+ // CHECK: }
+ // CHECK: }
+ %out1 = linalg.matmul ins(%A1, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+ %out2 = linalg.matmul ins(%A2, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
}
-
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) {
%matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
@@ -66,18 +176,84 @@ module attributes {transform.with_named_sequence} {
// -----
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+// CHECK-NOCLEANUP: func.func @fuse_no_iter_args([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @fuse_no_iter_args(%A: tensor<128xf32>, %B: tensor<128xf32>) {
+ // CHECK-NOCLEANUP: [[C0:%.*]] = arith.constant 0 : index
+ // CHECK-NOCLEANUP: [[C16:%.*]] = arith.constant 16 : index
+ // CHECK-NOCLEANUP: [[C128:%.*]] = arith.constant 128 : index
+ // CHECK-NOCLEANUP: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c128 = arith.constant 128 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ // CHECK-NOCLEANUP: scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C16]] {{.*}}
+ scf.for %arg0 = %c0 to %c128 step %c16 {
+ // CHECK-NOCLEANUP: [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]]
+ %2 = vector.transfer_read %A[%arg0], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ scf.yield
+ }
+ scf.for %arg0 = %c0 to %c128 step %c16 {
+ // CHECK-NOCLEANUP: [[BSLICE:%.*]] = vector.transfer_read [[B]][[[IV]]], [[ZERO]]
+ %dup2 = vector.transfer_read %B[%arg0], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ scf.yield
+ }
+ return
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %for:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %fused = transform.loop.fuse_sibling %for#0 into %for#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @source_for_uses_result_of_target_for_err(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c128 = arith.constant 128 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ // expected-error @below {{user of results of target should be properly dominated by source}}
+ %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %A) -> (tensor<128xf32>) {
+ %2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ %5 = arith.addf %3, %2 : vector<16xf32>
+ %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+ scf.yield %6 : tensor<128xf32>
+ }
+ %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %1) -> (tensor<128xf32>) {
+ %dup2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ %dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
+ %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+ scf.yield %dup6 : tensor<128xf32>
+ }
+ return %1, %dup1 : tensor<128xf32>, tensor<128xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %for:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %fused = transform.loop.fuse_sibling %for#0 into %for#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @source_forall_uses_result_of_target_forall_err(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
%zero = arith.constant 0.0 : f32
%out_alloc = tensor.empty() : tensor<128x128xf32>
%out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+ // expected-error @below {{user of results of target should be properly dominated by source}}
%out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
- // expected-error @below {{values used inside regions of target should be properly dominated by source}}
%out2 = linalg.matmul ins(%A, %out1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
}
-
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) {
%matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
@@ -87,25 +263,58 @@ module attributes {transform.with_named_sequence} {
%tiled_mm1, %loop1 = transform.structured.tile_using_forall %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%tiled_mm2, %loop2 = transform.structured.tile_using_forall %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- %fused_loop = transform.loop.fuse_sibling %loop2 into %loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+ %fused_loop = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op
transform.yield
}
}
// -----
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
- %zero = arith.constant 0.0 : f32
- %out_alloc = tensor.empty() : tensor<128x128xf32>
- %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+func.func @target_for_region_uses_result_of_source_for_err(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c128 = arith.constant 128 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %A) -> (tensor<128xf32>) {
+ %2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ %5 = arith.addf %3, %2 : vector<16xf32>
+ %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+ scf.yield %6 : tensor<128xf32>
+ }
+ %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) {
+ // expected-error @below {{values used inside regions of target should be properly dominated by source}}
+ %dup2 = vector.transfer_read %1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ %dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
+ %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+ scf.yield %dup6 : tensor<128xf32>
+ }
+ return %1, %dup1 : tensor<128xf32>, tensor<128xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %for:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %fused = transform.loop.fuse_sibling %for#1 into %for#0 : (!transform.any_op,!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
- %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+// -----
+
+func.func @target_forall_depends_on_value_not_dominated_by_source_forall_err(%A1 : tensor<128x128xf32>, %A2 : tensor<128x128xf32>, %B : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+ %zero = arith.constant 0.0 : f32
+ %buf1_alloc = tensor.empty() : tensor<128x128xf32>
+ %buf1 = linalg.fill ins(%zero : f32) outs(%buf1_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+ %out1 = linalg.matmul ins(%A1, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%buf1 : tensor<128x128xf32>) -> tensor<128x128xf32>
+ %out_alloc2 = tensor.empty() : tensor<128x128xf32>
+ %buf2 = linalg.fill ins(%zero : f32) outs(%buf1_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
// expected-error @below {{operands of target should be properly dominated by source}}
- %out2 = linalg.matmul ins(%A, %B2 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out1 : tensor<128x128xf32>) -> tensor<128x128xf32>
+ %out2 = linalg.matmul ins(%A2, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%buf2 : tensor<128x128xf32>) -> tensor<128x128xf32>
func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
}
-
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) {
%matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)