[Mlir-commits] [mlir] [SCF][Transform] Add support for scf.for in LoopFuseSibling op (PR #81495)

Rolf Morel llvmlistbot at llvm.org
Fri Mar 22 11:41:11 PDT 2024


https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/81495

>From 69e562025746835889da4c5d219dd176c458b683 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at huawei.com>
Date: Tue, 30 Jan 2024 00:14:57 +0800
Subject: [PATCH] [SCF][Transform] Add support for scf.for in LoopFuseSibling
 op

Adds support for fusing two scf.for loops occurring in the same block.
Implementation mirrors that of LoopFuseSibling's support for scf.forall,
including only rudimentary checks, like the target loop's operands being
dominated by the source loop.

Fixes a bug in the dominance check whereby it was checked that the ops using
values inside the target loop dominated the source loop, rather than the ops
that define these values.

Adds tests for using LoopFuseSibling on scf.for loops, including one which
fails without the fix for the dominance check.
---
 .../SCF/TransformOps/SCFTransformOps.td       |  24 +-
 mlir/include/mlir/Dialect/SCF/Utils/Utils.h   |  10 +
 .../SCF/TransformOps/SCFTransformOps.cpp      |  59 +++--
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          |  97 +++++---
 .../SCF/transform-loop-fuse-sibling.mlir      | 218 ++++++++++++++++--
 5 files changed, 324 insertions(+), 84 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 6f94cee5b01911..120cfaba1fe15b 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -333,23 +333,25 @@ def TakeAssumedBranchOp : Op<Transform_Dialect, "scf.take_assumed_branch", [
   }];
 }
 
-def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
+def LoopFuseSiblingOp : Op<Transform_Dialect, "loop.fuse_sibling",
   [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
    DeclareOpInterfaceMethods<TransformOpInterface>]> {
   let summary = "Fuse a loop into another loop, assuming the fusion is legal.";
 
   let description = [{
     Fuses the `target` loop into the `source` loop assuming they are
-    independent of each other. It is the responsibility of the user to ensure
-    that the given two loops are independent of each other, this operation will
-    not performa any legality checks and will simply fuse the two given loops.
+    independent of each other. In the fused loop, the arguments, body and
+    results of `target` are placed _before_ those of `source`.
 
-    Currently, the only fusion supported is when both `target` and `source`
-    are `scf.forall` operations. For `scf.forall` fusion, the bounds and the
-    mapping must match, otherwise a silencable failure is produced.
+    For fusion of two `scf.for` loops, the bounds and step size must match. For
+    fusion of two `scf.forall` loops, the bounds and the mapping must match.
+    Otherwise a silenceable failure is produced. Attempting to fuse any other kinds
+    of loops/ops will produce a definite failure.
 
-    The input handles `target` and `source` must map to exactly one operation,
-    a definite failure is produced otherwise.
+    The `target` and `source` handles must refer to exactly one operation,
+    otherwise a definite failure is produced. It is the responsibility of the
+    user to ensure that the `target` and `source` loops are independent of each
+    other -- this op will not perform any legality checks.
 
     #### Return modes
 
@@ -362,10 +364,6 @@ def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
   let results = (outs TransformHandleTypeInterface:$fused_loop);
   let assemblyFormat = "$target `into` $source attr-dict "
                        " `:` functional-type(operands, results)";
-
-  let builders = [
-    OpBuilder<(ins "Value":$loop, "Value":$fused_loop)>
-  ];
 }
 
 #endif // SCF_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 9bdd6eb833876f..883d11bcc4df06 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -162,6 +162,16 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
                                                 scf::ForallOp source,
                                                 RewriterBase &rewriter);
 
+/// Given two scf.for loops, `target` and `source`, fuses `target` into
+/// `source`. Assumes that the given loops are siblings and are independent of
+/// each other.
+///
+/// This function does not perform any legality checks and simply fuses the
+/// loops. The caller is responsible for ensuring that the loops are legal to
+/// fuse.
+scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
+                                          RewriterBase &rewriter);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 4d8d93f7aac720..04228691ab5305 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -384,7 +384,7 @@ void transform::TakeAssumedBranchOp::getEffects(
 }
 
 //===----------------------------------------------------------------------===//
-// LoopFuseSibling
+// LoopFuseSiblingOp
 //===----------------------------------------------------------------------===//
 
 /// Check if `target` and `source` are siblings, in the context that `target`
@@ -408,7 +408,7 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
   // Check if fusion will violate dominance.
   DominanceInfo domInfo(source);
   if (target->isBeforeInBlock(source)) {
-    // Since, `target` is before `source`, all users of results of `target`
+    // Since `target` is before `source`, all users of results of `target`
     // need to be dominated by `source`.
     for (Operation *user : target->getUsers()) {
       if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
@@ -424,9 +424,8 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
     // Check if operands of `target` are dominated by `source`.
     for (Value operand : target->getOperands()) {
       Operation *operandOp = operand.getDefiningOp();
-      // If operand does not have a defining operation, it is a block arguement,
-      // which will always dominate `source`, since `target` and `source` are in
-      // the same block and the operand dominated `source` before.
+      // Operands without defining operations are block arguments. When `target`
+      // and `source` occur in the same block, these operands dominate `source`.
       if (!operandOp)
         continue;
 
@@ -441,8 +440,12 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
     bool failed = false;
     OpOperand *failedValue = nullptr;
     visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
-      if (!domInfo.properlyDominates(operand->getOwner(), source,
-                                     /*enclosingOpOk=*/false)) {
+      Operation *operandOp = operand->get().getDefiningOp();
+      if (operandOp && !domInfo.properlyDominates(operandOp, source,
+                                                  /*enclosingOpOk=*/false)) {
+        // `operand` is not an argument of an enclosing block and the defining
+        // op of `operand` is outside `target` but does not dominate `source`.
+        // (Block arguments have no defining op and are skipped by the check.)
         failed = true;
         failedValue = operand;
       }
@@ -476,21 +479,40 @@ static bool isForallWithIdenticalConfiguration(Operation *target,
          targetOp.getMapping() == sourceOp.getMapping();
 }
 
-/// Fuse `target` into `source` assuming they are siblings and indepndent.
-/// TODO: Add fusion for more operations. Currently, we handle only scf.forall.
+static bool isForWithIdenticalConfiguration(Operation *target,
+                                            Operation *source) {
+  auto targetOp = dyn_cast<scf::ForOp>(target);
+  auto sourceOp = dyn_cast<scf::ForOp>(source);
+  if (!targetOp || !sourceOp)
+    return false;
+
+  return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
+         targetOp.getUpperBound() == sourceOp.getUpperBound() &&
+         targetOp.getStep() == sourceOp.getStep();
+}
+
+/// Fuse `target` into `source` assuming they are siblings and independent.
+/// TODO: Support fusion for operations besides scf.for and scf.forall.
 static Operation *fuseSiblings(Operation *target, Operation *source,
                                RewriterBase &rewriter) {
-  auto targetOp = dyn_cast<scf::ForallOp>(target);
-  auto sourceOp = dyn_cast<scf::ForallOp>(source);
-  if (!targetOp || !sourceOp)
-    return nullptr;
-  return fuseIndependentSiblingForallLoops(targetOp, sourceOp, rewriter);
+  auto targetForOp = dyn_cast<scf::ForOp>(target);
+  auto sourceForOp = dyn_cast<scf::ForOp>(source);
+  if (targetForOp && sourceForOp)
+    return fuseIndependentSiblingForLoops(targetForOp, sourceForOp, rewriter);
+
+  auto targetForallOp = dyn_cast<scf::ForallOp>(target);
+  auto sourceForallOp = dyn_cast<scf::ForallOp>(source);
+  if (targetForallOp && sourceForallOp)
+    return fuseIndependentSiblingForallLoops(targetForallOp, sourceForallOp,
+                                             rewriter);
+
+  return nullptr;
 }
 
 DiagnosedSilenceableFailure
-transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
-                                  transform::TransformResults &results,
-                                  transform::TransformState &state) {
+transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
+                                    transform::TransformResults &results,
+                                    transform::TransformState &state) {
   auto targetOps = state.getPayloadOps(getTarget());
   auto sourceOps = state.getPayloadOps(getSource());
 
@@ -511,7 +533,8 @@ transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
     return diag;
 
   // Check if the target can be fused into source.
-  if (!isForallWithIdenticalConfiguration(target, source)) {
+  if (!isForallWithIdenticalConfiguration(target, source) &&
+      !isForWithIdenticalConfiguration(target, source)) {
     return emitSilenceableFailure(target->getLoc())
            << "operations cannot be fused";
   }
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 502d7e197a6f6b..ab04b16b27b0df 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -910,61 +910,96 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
   unsigned numTargetOuts = target.getNumResults();
   unsigned numSourceOuts = source.getNumResults();
 
-  OperandRange targetOuts = target.getOutputs();
-  OperandRange sourceOuts = source.getOutputs();
-
   // Create fused shared_outs.
   SmallVector<Value> fusedOuts;
-  fusedOuts.reserve(numTargetOuts + numSourceOuts);
-  fusedOuts.append(targetOuts.begin(), targetOuts.end());
-  fusedOuts.append(sourceOuts.begin(), sourceOuts.end());
+  llvm::append_range(fusedOuts, target.getOutputs());
+  llvm::append_range(fusedOuts, source.getOutputs());
 
-  // Create a new scf::forall op after the source loop.
+  // Create a new scf.forall op after the source loop.
   rewriter.setInsertionPointAfter(source);
   scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
       source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
       source.getMixedStep(), fusedOuts, source.getMapping());
 
   // Map control operands.
-  IRMapping fusedMapping;
-  fusedMapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
-  fusedMapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
+  IRMapping mapping;
+  mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
+  mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
 
   // Map shared outs.
-  fusedMapping.map(target.getRegionIterArgs(),
-                   fusedLoop.getRegionIterArgs().slice(0, numTargetOuts));
-  fusedMapping.map(
-      source.getRegionIterArgs(),
-      fusedLoop.getRegionIterArgs().slice(numTargetOuts, numSourceOuts));
+  mapping.map(target.getRegionIterArgs(),
+              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
+  mapping.map(source.getRegionIterArgs(),
+              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
 
   // Append everything except the terminator into the fused operation.
   rewriter.setInsertionPointToStart(fusedLoop.getBody());
   for (Operation &op : target.getBody()->without_terminator())
-    rewriter.clone(op, fusedMapping);
+    rewriter.clone(op, mapping);
   for (Operation &op : source.getBody()->without_terminator())
-    rewriter.clone(op, fusedMapping);
+    rewriter.clone(op, mapping);
 
   // Fuse the old terminator in_parallel ops into the new one.
   scf::InParallelOp targetTerm = target.getTerminator();
   scf::InParallelOp sourceTerm = source.getTerminator();
   scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
-
   rewriter.setInsertionPointToStart(fusedTerm.getBody());
   for (Operation &op : targetTerm.getYieldingOps())
-    rewriter.clone(op, fusedMapping);
+    rewriter.clone(op, mapping);
   for (Operation &op : sourceTerm.getYieldingOps())
-    rewriter.clone(op, fusedMapping);
-
-  // Replace all uses of the old loops with the fused loop.
-  rewriter.replaceAllUsesWith(target.getResults(),
-                              fusedLoop.getResults().slice(0, numTargetOuts));
-  rewriter.replaceAllUsesWith(
-      source.getResults(),
-      fusedLoop.getResults().slice(numTargetOuts, numSourceOuts));
-
-  // Erase the old loops.
-  rewriter.eraseOp(target);
-  rewriter.eraseOp(source);
+    rewriter.clone(op, mapping);
+
+  // Replace old loops by substituting their uses by results of the fused loop.
+  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
+  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
+
+  return fusedLoop;
+}
+
+scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
+                                                scf::ForOp source,
+                                                RewriterBase &rewriter) {
+  unsigned numTargetOuts = target.getNumResults();
+  unsigned numSourceOuts = source.getNumResults();
+
+  // Create fused init_args, with target's init_args before source's init_args.
+  SmallVector<Value> fusedInitArgs;
+  llvm::append_range(fusedInitArgs, target.getInitArgs());
+  llvm::append_range(fusedInitArgs, source.getInitArgs());
+
+  // Create a new scf.for op after the source loop.
+  rewriter.setInsertionPointAfter(source);
+  scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
+      source.getLoc(), source.getLowerBound(), source.getUpperBound(),
+      source.getStep(), fusedInitArgs);
+
+  // Map original induction variables and operands to those of the fused loop.
+  IRMapping mapping;
+  mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
+  mapping.map(target.getRegionIterArgs(),
+              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
+  mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
+  mapping.map(source.getRegionIterArgs(),
+              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
+
+  // Merge target's body into the new (fused) for loop and then source's body.
+  rewriter.setInsertionPointToStart(fusedLoop.getBody());
+  for (Operation &op : target.getBody()->without_terminator())
+    rewriter.clone(op, mapping);
+  for (Operation &op : source.getBody()->without_terminator())
+    rewriter.clone(op, mapping);
+
+  // Build fused yield results by appropriately mapping original yield operands.
+  SmallVector<Value> yieldResults;
+  for (Value operand : target.getBody()->getTerminator()->getOperands())
+    yieldResults.push_back(mapping.lookupOrDefault(operand));
+  for (Value operand : source.getBody()->getTerminator()->getOperands())
+    yieldResults.push_back(mapping.lookupOrDefault(operand));
+  rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
+
+  // Replace old loops by substituting their uses by results of the fused loop.
+  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
+  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
 
   return fusedLoop;
 }
diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
index faaa2db3aa57de..e96a11a99144d3 100644
--- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -1,14 +1,112 @@
 // RUN: mlir-opt %s -transform-interpreter --cse --canonicalize -split-input-file -verify-diagnostics | FileCheck %s
 
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+// CHECK: func.func @fuse_1st_for_into_2nd([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @fuse_1st_for_into_2nd(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
+  // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+  // CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index
+  // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index
+  // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c128 = arith.constant 128 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK: [[R0:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C16]] iter_args([[IA:%.*]] = [[A]], [[IB:%.*]] = [[B]]) {{.*}}
+  %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %A) -> (tensor<128xf32>) {
+  // CHECK-DAG:   [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]]
+  // CHECK-DAG:   [[SLICE0:%.*]] = vector.transfer_read [[IA]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IA]][[[IV]]]
+    %2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %5 = arith.addf %3, %2 : vector<16xf32>
+    %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %6 : tensor<128xf32>
+  }
+  %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) {
+  // CHECK-DAG:   [[SLICE1:%.*]] = vector.transfer_read [[IB]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT2:%.*]] = arith.addf [[SLICE1]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT1:%.*]] = vector.transfer_write [[OUT2]], [[IB]][[[IV]]]
+    %dup2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
+    %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+  // CHECK: scf.yield [[WRT0]], [[WRT1]] : {{.*}}
+    scf.yield %dup6 : tensor<128xf32>
+  }
+  return %1, %dup1 : tensor<128xf32>, tensor<128xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %for:2 = transform.split_handle %0 :  (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %for#0 into %for#1 : (!transform.any_op,!transform.any_op) ->  !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK: func.func @fuse_2nd_for_into_1st([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @fuse_2nd_for_into_1st(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
+  // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+  // CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index
+  // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index
+  // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c128 = arith.constant 128 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK: [[R0:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C16]] iter_args([[IB:%.*]] = [[B]], [[IA:%.*]] = [[A]]) {{.*}}
+  %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %A) -> (tensor<128xf32>) {
+  // CHECK-DAG:   [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]]
+  // CHECK-DAG:   [[SLICE0:%.*]] = vector.transfer_read [[IB]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IB]][[[IV]]]
+    %2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %5 = arith.addf %3, %2 : vector<16xf32>
+    %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %6 : tensor<128xf32>
+  }
+  %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) {
+  // CHECK-DAG:   [[SLICE1:%.*]] = vector.transfer_read [[IA]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT2:%.*]] = arith.addf [[SLICE1]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT1:%.*]] = vector.transfer_write [[OUT2]], [[IA]][[[IV]]]
+    %dup2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+  // NB: the dominance check used to fail on the following line,
+  // however the defining op for the value of %arg3 occurs above the source loop and hence is safe
+  // and %arg4 is a block argument of the scope of the loops and hence is safe
+    %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
+    %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+  // CHECK: scf.yield [[WRT0]], [[WRT1]] : {{.*}}
+    scf.yield %dup6 : tensor<128xf32>
+  }
+  return %1, %dup1 : tensor<128xf32>, tensor<128xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %for:2 = transform.split_handle %0 :  (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %for#1 into %for#0 : (!transform.any_op,!transform.any_op) ->  !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK: func.func @matmul_fuse_1st_forall_into_2nd([[A1:%.*]]: {{.*}}, [[A2:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @matmul_fuse_1st_forall_into_2nd(%A1 : tensor<128x128xf32>, %A2 : tensor<128x128xf32>, %B : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
   %zero = arith.constant 0.0 : f32
   %out_alloc = tensor.empty() : tensor<128x128xf32>
   %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
 
   // CHECK: scf.forall ([[I:%.*]]) in (4) shared_outs([[S1:%.*]] = [[IN1:%.*]], [[S2:%.*]] = [[IN2:%.*]]) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
   // CHECK:   [[T:%.*]] = affine.apply
+  // CHECK:   tensor.extract_slice [[A2]][[[T]], 0] [32, 128] [1, 1]
   // CHECK:   tensor.extract_slice [[S1]][[[T]], 0] [32, 128] [1, 1]
   // CHECK:   [[OUT1:%.*]] = linalg.matmul
+  // CHECK:   tensor.extract_slice [[A1]][[[T]], 0] [32, 128] [1, 1]
   // CHECK:   tensor.extract_slice [[S2]][[[T]], 0] [32, 128] [1, 1]
   // CHECK:   [[OUT2:%.*]] = linalg.matmul
   // CHECK:   scf.forall.in_parallel {
@@ -16,12 +114,11 @@ func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tenso
   // CHECK:     tensor.parallel_insert_slice [[OUT2]] into [[S2]][[[T]], 0] [32, 128] [1, 1]
   // CHECK:   }
   // CHECK: }
-  %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
-  %out2 = linalg.matmul ins(%A, %B2 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out1 = linalg.matmul ins(%A1, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out2 = linalg.matmul ins(%A2, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
 
   func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
 }
-
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) {
     %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
@@ -31,25 +128,37 @@ module attributes {transform.with_named_sequence} {
     %tiled_mm1, %loop1 = transform.structured.tile_using_forall %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
     %tiled_mm2, %loop2 = transform.structured.tile_using_forall %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 
-    %fused_loop = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    %fused_loop = transform.loop.fuse_sibling %loop2 into %loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
 
 // -----
 
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+// CHECK: func.func @matmul_fuse_2nd_forall_into_1st([[A1:%.*]]: {{.*}}, [[A2:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @matmul_fuse_2nd_forall_into_1st(%A1 : tensor<128x128xf32>, %A2 : tensor<128x128xf32>, %B : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
   %zero = arith.constant 0.0 : f32
   %out_alloc = tensor.empty() : tensor<128x128xf32>
   %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
 
-  // expected-error @below {{user of results of target should be properly dominated by source}}
-  %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
-  %out2 = linalg.matmul ins(%A, %out1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+  // CHECK: scf.forall ([[I:%.*]]) in (4) shared_outs([[S1:%.*]] = [[IN1:%.*]], [[S2:%.*]] = [[IN2:%.*]]) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+  // CHECK:   [[T:%.*]] = affine.apply
+  // CHECK:   tensor.extract_slice [[A1]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:   tensor.extract_slice [[S1]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:   [[OUT1:%.*]] = linalg.matmul
+  // CHECK:   tensor.extract_slice [[A2]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:   tensor.extract_slice [[S2]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:   [[OUT2:%.*]] = linalg.matmul
+  // CHECK:   scf.forall.in_parallel {
+  // CHECK:     tensor.parallel_insert_slice [[OUT1]] into [[S1]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:     tensor.parallel_insert_slice [[OUT2]] into [[S2]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:   }
+  // CHECK: }
+  %out1 = linalg.matmul ins(%A1, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out2 = linalg.matmul ins(%A2, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
 
   func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
 }
-
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) {
     %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
@@ -66,18 +175,50 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+func.func @source_for_uses_result_of_target_for_err(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c128 = arith.constant 128 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  // expected-error @below {{user of results of target should be properly dominated by source}}
+  %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %A) -> (tensor<128xf32>) {
+    %2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %5 = arith.addf %3, %2 : vector<16xf32>
+    %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %6 : tensor<128xf32>
+  }
+  %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %1) -> (tensor<128xf32>) {
+    %dup2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
+    %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %dup6 : tensor<128xf32>
+  }
+  return %1, %dup1 : tensor<128xf32>, tensor<128xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %for:2 = transform.split_handle %0 :  (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %for#0 into %for#1 : (!transform.any_op,!transform.any_op) ->  !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @source_forall_uses_result_of_target_forall_err(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
   %zero = arith.constant 0.0 : f32
   %out_alloc = tensor.empty() : tensor<128x128xf32>
   %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
 
+  // expected-error @below {{user of results of target should be properly dominated by source}}
   %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
-  // expected-error @below {{values used inside regions of target should be properly dominated by source}}
   %out2 = linalg.matmul ins(%A, %out1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
 
   func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
 }
-
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) {
     %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
@@ -87,25 +228,58 @@ module attributes {transform.with_named_sequence} {
     %tiled_mm1, %loop1 = transform.structured.tile_using_forall %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
     %tiled_mm2, %loop2 = transform.structured.tile_using_forall %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 
-    %fused_loop = transform.loop.fuse_sibling %loop2 into %loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    %fused_loop = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
 
 // -----
 
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
-  %zero = arith.constant 0.0 : f32
-  %out_alloc = tensor.empty() : tensor<128x128xf32>
-  %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+func.func @target_for_region_uses_result_of_source_for_err(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c128 = arith.constant 128 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %A) -> (tensor<128xf32>) {
+    %2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %5 = arith.addf %3, %2 : vector<16xf32>
+    %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %6 : tensor<128xf32>
+  }
+  %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) {
+  // expected-error @below {{values used inside regions of target should be properly dominated by source}}
+    %dup2 = vector.transfer_read %1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
+    %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %dup6 : tensor<128xf32>
+  }
+  return %1, %dup1 : tensor<128xf32>, tensor<128xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %for:2 = transform.split_handle %0 :  (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %for#1 into %for#0 : (!transform.any_op,!transform.any_op) ->  !transform.any_op
+    transform.yield
+  }
+}
 
-  %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+// -----
+
+func.func @target_forall_depends_on_value_not_dominated_by_source_forall_err(%A1 : tensor<128x128xf32>, %A2 : tensor<128x128xf32>, %B : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+  %zero = arith.constant 0.0 : f32
+  %buf1_alloc = tensor.empty() : tensor<128x128xf32>
+  %buf1 = linalg.fill ins(%zero : f32) outs(%buf1_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out1 = linalg.matmul ins(%A1, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%buf1 : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out_alloc2 = tensor.empty() : tensor<128x128xf32>
+  %buf2 = linalg.fill ins(%zero : f32) outs(%buf1_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
   // expected-error @below {{operands of target should be properly dominated by source}}
-  %out2 = linalg.matmul ins(%A, %B2 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out1 : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out2 = linalg.matmul ins(%A2, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%buf2 : tensor<128x128xf32>) -> tensor<128x128xf32>
 
   func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
 }
-
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) {
     %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
@@ -118,4 +292,4 @@ module attributes {transform.with_named_sequence} {
     %fused_loop = transform.loop.fuse_sibling %loop2 into %loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
     transform.yield
   }
-}
+}
\ No newline at end of file



More information about the Mlir-commits mailing list