[Mlir-commits] [mlir] [SCF][Transform] Add support for scf.for in LoopFuseSibling op (PR #81495)

Rolf Morel llvmlistbot at llvm.org
Mon Mar 25 04:30:59 PDT 2024


https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/81495

>From 98e2ab477f23528a54c037a0055ce811ac356c28 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at huawei.com>
Date: Tue, 30 Jan 2024 00:14:57 +0800
Subject: [PATCH] [SCF][Transform] Add support for scf.for in LoopFuseSibling
 op

Adds support for fusing two scf.for loops occurring in the same block.
Uses the rudimentary checks already in place for scf.forall (like the target
loop's operands being dominated by the source loop).

- Fixes a bug in the dominance check whereby it was checked that the ops using
  values inside the target loop dominated the source loop, rather than the ops
  that define these values.
- Renames the LoopFuseSibling op to LoopFuseSiblingOp.
- Updates the LoopFuseSiblingOp's description.
- Adds tests for using LoopFuseSiblingOp on scf.for loops, including one which
  fails without the fix for the dominance check.
- Adds tests checking the different failure modes of the dominance checker.
- Adds a test for the case whereby scf.yield is automatically generated when
  there are no loop-carried variables.
---
 .../SCF/TransformOps/SCFTransformOps.td       |  23 +-
 mlir/include/mlir/Dialect/SCF/Utils/Utils.h   |  10 +
 .../SCF/TransformOps/SCFTransformOps.cpp      |  58 ++--
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          |  99 ++++---
 .../SCF/transform-loop-fuse-sibling.mlir      | 253 ++++++++++++++++--
 5 files changed, 359 insertions(+), 84 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 6f94cee5b01911..5eefe2664d0a1b 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -333,23 +333,24 @@ def TakeAssumedBranchOp : Op<Transform_Dialect, "scf.take_assumed_branch", [
   }];
 }
 
-def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
+def LoopFuseSiblingOp : Op<Transform_Dialect, "loop.fuse_sibling",
   [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
    DeclareOpInterfaceMethods<TransformOpInterface>]> {
   let summary = "Fuse a loop into another loop, assuming the fusion is legal.";
 
   let description = [{
     Fuses the `target` loop into the `source` loop assuming they are
-    independent of each other. It is the responsibility of the user to ensure
-    that the given two loops are independent of each other, this operation will
-    not performa any legality checks and will simply fuse the two given loops.
+    independent of each other. In the fused loop, the arguments, body and
+    results of `target` are placed _before_ those of `source`.
 
-    Currently, the only fusion supported is when both `target` and `source`
-    are `scf.forall` operations. For `scf.forall` fusion, the bounds and the
-    mapping must match, otherwise a silencable failure is produced.
+    For fusion of two `scf.for` loops, the bounds and step size must match. For
+    fusion of two `scf.forall` loops, the bounds and the mapping must match.
+    Otherwise a silenceable failure is produced.
 
-    The input handles `target` and `source` must map to exactly one operation,
-    a definite failure is produced otherwise.
+    The `target` and `source` handles must refer to exactly one operation,
+    otherwise a definite failure is produced. It is the responsibility of the
+    user to ensure that the `target` and `source` loops are independent of each
+    other -- this op will only perform rudimentary legality checks.
 
     #### Return modes
 
@@ -362,10 +363,6 @@ def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
   let results = (outs TransformHandleTypeInterface:$fused_loop);
   let assemblyFormat = "$target `into` $source attr-dict "
                        " `:` functional-type(operands, results)";
-
-  let builders = [
-    OpBuilder<(ins "Value":$loop, "Value":$fused_loop)>
-  ];
 }
 
 #endif // SCF_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 9bdd6eb833876f..883d11bcc4df06 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -162,6 +162,16 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
                                                 scf::ForallOp source,
                                                 RewriterBase &rewriter);
 
+/// Given two scf.for loops, `target` and `source`, fuses `target` into
+/// `source`. Assumes that the given loops are siblings and are independent of
+/// each other.
+///
+/// This function does not perform any legality checks and simply fuses the
+/// loops. The caller is responsible for ensuring that the loops are legal to
+/// fuse.
+scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
+                                          RewriterBase &rewriter);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 4d8d93f7aac720..45460ac27cee6c 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -384,7 +384,7 @@ void transform::TakeAssumedBranchOp::getEffects(
 }
 
 //===----------------------------------------------------------------------===//
-// LoopFuseSibling
+// LoopFuseSiblingOp
 //===----------------------------------------------------------------------===//
 
 /// Check if `target` and `source` are siblings, in the context that `target`
@@ -408,7 +408,7 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
   // Check if fusion will violate dominance.
   DominanceInfo domInfo(source);
   if (target->isBeforeInBlock(source)) {
-    // Since, `target` is before `source`, all users of results of `target`
+    // Since `target` is before `source`, all users of results of `target`
     // need to be dominated by `source`.
     for (Operation *user : target->getUsers()) {
       if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
@@ -424,9 +424,8 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
     // Check if operands of `target` are dominated by `source`.
     for (Value operand : target->getOperands()) {
       Operation *operandOp = operand.getDefiningOp();
-      // If operand does not have a defining operation, it is a block arguement,
-      // which will always dominate `source`, since `target` and `source` are in
-      // the same block and the operand dominated `source` before.
+      // Operands without defining operations are block arguments. When `target`
+      // and `source` occur in the same block, these operands dominate `source`.
       if (!operandOp)
         continue;
 
@@ -441,8 +440,11 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
     bool failed = false;
     OpOperand *failedValue = nullptr;
     visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
-      if (!domInfo.properlyDominates(operand->getOwner(), source,
-                                     /*enclosingOpOk=*/false)) {
+      Operation *operandOp = operand->get().getDefiningOp();
+      if (operandOp && !domInfo.properlyDominates(operandOp, source,
+                                                  /*enclosingOpOk=*/false)) {
+        // `operand` is not an argument of an enclosing block and the defining
+        // op of `operand` is outside `target` but does not dominate `source`.
         failed = true;
         failedValue = operand;
       }
@@ -476,21 +478,40 @@ static bool isForallWithIdenticalConfiguration(Operation *target,
          targetOp.getMapping() == sourceOp.getMapping();
 }
 
-/// Fuse `target` into `source` assuming they are siblings and indepndent.
-/// TODO: Add fusion for more operations. Currently, we handle only scf.forall.
+static bool isForWithIdenticalConfiguration(Operation *target,
+                                            Operation *source) {
+  auto targetOp = dyn_cast<scf::ForOp>(target);
+  auto sourceOp = dyn_cast<scf::ForOp>(source);
+  if (!targetOp || !sourceOp)
+    return false;
+
+  return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
+         targetOp.getUpperBound() == sourceOp.getUpperBound() &&
+         targetOp.getStep() == sourceOp.getStep();
+}
+
+/// Fuse `target` into `source` assuming they are siblings and independent.
+/// TODO: Support fusion for operations besides scf.for and scf.forall.
 static Operation *fuseSiblings(Operation *target, Operation *source,
                                RewriterBase &rewriter) {
-  auto targetOp = dyn_cast<scf::ForallOp>(target);
-  auto sourceOp = dyn_cast<scf::ForallOp>(source);
-  if (!targetOp || !sourceOp)
-    return nullptr;
-  return fuseIndependentSiblingForallLoops(targetOp, sourceOp, rewriter);
+  auto targetForOp = dyn_cast<scf::ForOp>(target);
+  auto sourceForOp = dyn_cast<scf::ForOp>(source);
+  if (targetForOp && sourceForOp)
+    return fuseIndependentSiblingForLoops(targetForOp, sourceForOp, rewriter);
+
+  auto targetForallOp = dyn_cast<scf::ForallOp>(target);
+  auto sourceForallOp = dyn_cast<scf::ForallOp>(source);
+  if (targetForallOp && sourceForallOp)
+    return fuseIndependentSiblingForallLoops(targetForallOp, sourceForallOp,
+                                             rewriter);
+
+  return nullptr;
 }
 
 DiagnosedSilenceableFailure
-transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
-                                  transform::TransformResults &results,
-                                  transform::TransformState &state) {
+transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
+                                    transform::TransformResults &results,
+                                    transform::TransformState &state) {
   auto targetOps = state.getPayloadOps(getTarget());
   auto sourceOps = state.getPayloadOps(getSource());
 
@@ -511,7 +532,8 @@ transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
     return diag;
 
   // Check if the target can be fused into source.
-  if (!isForallWithIdenticalConfiguration(target, source)) {
+  if (!isForallWithIdenticalConfiguration(target, source) &&
+      !isForWithIdenticalConfiguration(target, source)) {
     return emitSilenceableFailure(target->getLoc())
            << "operations cannot be fused";
   }
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 502d7e197a6f6b..346eb8e1af6c67 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -910,61 +910,98 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
   unsigned numTargetOuts = target.getNumResults();
   unsigned numSourceOuts = source.getNumResults();
 
-  OperandRange targetOuts = target.getOutputs();
-  OperandRange sourceOuts = source.getOutputs();
-
   // Create fused shared_outs.
   SmallVector<Value> fusedOuts;
-  fusedOuts.reserve(numTargetOuts + numSourceOuts);
-  fusedOuts.append(targetOuts.begin(), targetOuts.end());
-  fusedOuts.append(sourceOuts.begin(), sourceOuts.end());
+  llvm::append_range(fusedOuts, target.getOutputs());
+  llvm::append_range(fusedOuts, source.getOutputs());
 
-  // Create a new scf::forall op after the source loop.
+  // Create a new scf.forall op after the source loop.
   rewriter.setInsertionPointAfter(source);
   scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
       source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
       source.getMixedStep(), fusedOuts, source.getMapping());
 
   // Map control operands.
-  IRMapping fusedMapping;
-  fusedMapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
-  fusedMapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
+  IRMapping mapping;
+  mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
+  mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
 
   // Map shared outs.
-  fusedMapping.map(target.getRegionIterArgs(),
-                   fusedLoop.getRegionIterArgs().slice(0, numTargetOuts));
-  fusedMapping.map(
-      source.getRegionIterArgs(),
-      fusedLoop.getRegionIterArgs().slice(numTargetOuts, numSourceOuts));
+  mapping.map(target.getRegionIterArgs(),
+              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
+  mapping.map(source.getRegionIterArgs(),
+              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
 
   // Append everything except the terminator into the fused operation.
   rewriter.setInsertionPointToStart(fusedLoop.getBody());
   for (Operation &op : target.getBody()->without_terminator())
-    rewriter.clone(op, fusedMapping);
+    rewriter.clone(op, mapping);
   for (Operation &op : source.getBody()->without_terminator())
-    rewriter.clone(op, fusedMapping);
+    rewriter.clone(op, mapping);
 
   // Fuse the old terminator in_parallel ops into the new one.
   scf::InParallelOp targetTerm = target.getTerminator();
   scf::InParallelOp sourceTerm = source.getTerminator();
   scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
-
   rewriter.setInsertionPointToStart(fusedTerm.getBody());
   for (Operation &op : targetTerm.getYieldingOps())
-    rewriter.clone(op, fusedMapping);
+    rewriter.clone(op, mapping);
   for (Operation &op : sourceTerm.getYieldingOps())
-    rewriter.clone(op, fusedMapping);
-
-  // Replace all uses of the old loops with the fused loop.
-  rewriter.replaceAllUsesWith(target.getResults(),
-                              fusedLoop.getResults().slice(0, numTargetOuts));
-  rewriter.replaceAllUsesWith(
-      source.getResults(),
-      fusedLoop.getResults().slice(numTargetOuts, numSourceOuts));
-
-  // Erase the old loops.
-  rewriter.eraseOp(target);
-  rewriter.eraseOp(source);
+    rewriter.clone(op, mapping);
+
+  // Replace old loops by substituting their uses by results of the fused loop.
+  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
+  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
+
+  return fusedLoop;
+}
+
+scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
+                                                scf::ForOp source,
+                                                RewriterBase &rewriter) {
+  unsigned numTargetOuts = target.getNumResults();
+  unsigned numSourceOuts = source.getNumResults();
+
+  // Create fused init_args, with target's init_args before source's init_args.
+  SmallVector<Value> fusedInitArgs;
+  llvm::append_range(fusedInitArgs, target.getInitArgs());
+  llvm::append_range(fusedInitArgs, source.getInitArgs());
+
+  // Create a new scf.for op after the source loop (with scf.yield terminator
+  // (without arguments) only in case its init_args is empty).
+  rewriter.setInsertionPointAfter(source);
+  scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
+      source.getLoc(), source.getLowerBound(), source.getUpperBound(),
+      source.getStep(), fusedInitArgs);
+
+  // Map original induction variables and operands to those of the fused loop.
+  IRMapping mapping;
+  mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
+  mapping.map(target.getRegionIterArgs(),
+              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
+  mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
+  mapping.map(source.getRegionIterArgs(),
+              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
+
+  // Merge target's body into the new (fused) for loop and then source's body.
+  rewriter.setInsertionPointToStart(fusedLoop.getBody());
+  for (Operation &op : target.getBody()->without_terminator())
+    rewriter.clone(op, mapping);
+  for (Operation &op : source.getBody()->without_terminator())
+    rewriter.clone(op, mapping);
+
+  // Build fused yield results by appropriately mapping original yield operands.
+  SmallVector<Value> yieldResults;
+  for (Value operand : target.getBody()->getTerminator()->getOperands())
+    yieldResults.push_back(mapping.lookupOrDefault(operand));
+  for (Value operand : source.getBody()->getTerminator()->getOperands())
+    yieldResults.push_back(mapping.lookupOrDefault(operand));
+  if (yieldResults.size())
+    rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
+
+  // Replace old loops by substituting their uses by results of the fused loop.
+  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
+  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
 
   return fusedLoop;
 }
diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
index faaa2db3aa57de..4155894c1e0a54 100644
--- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -1,14 +1,113 @@
 // RUN: mlir-opt %s -transform-interpreter --cse --canonicalize -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s --check-prefix CHECK-NOCLEANUP
 
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+// CHECK: func.func @fuse_1st_for_into_2nd([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @fuse_1st_for_into_2nd(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
+  // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+  // CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index
+  // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index
+  // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c128 = arith.constant 128 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK: [[R0:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C16]] iter_args([[IA:%.*]] = [[A]], [[IB:%.*]] = [[B]]) {{.*}}
+  %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %A) -> (tensor<128xf32>) {
+  // CHECK-DAG:   [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]]
+  // CHECK-DAG:   [[SLICE0:%.*]] = vector.transfer_read [[IA]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IA]][[[IV]]]
+    %2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %5 = arith.addf %3, %2 : vector<16xf32>
+    %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %6 : tensor<128xf32>
+  }
+  %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) {
+  // CHECK-DAG:   [[SLICE1:%.*]] = vector.transfer_read [[IB]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT2:%.*]] = arith.addf [[SLICE1]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT1:%.*]] = vector.transfer_write [[OUT2]], [[IB]][[[IV]]]
+    %dup2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
+    %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+  // CHECK: scf.yield [[WRT0]], [[WRT1]] : {{.*}}
+    scf.yield %dup6 : tensor<128xf32>
+  }
+  return %1, %dup1 : tensor<128xf32>, tensor<128xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %for:2 = transform.split_handle %0 :  (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %for#0 into %for#1 : (!transform.any_op,!transform.any_op) ->  !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK: func.func @fuse_2nd_for_into_1st([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @fuse_2nd_for_into_1st(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
+  // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+  // CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index
+  // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index
+  // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c128 = arith.constant 128 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK: [[R0:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C16]] iter_args([[IB:%.*]] = [[B]], [[IA:%.*]] = [[A]]) {{.*}}
+  %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %A) -> (tensor<128xf32>) {
+  // CHECK-DAG:   [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]]
+  // CHECK-DAG:   [[SLICE0:%.*]] = vector.transfer_read [[IB]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IB]][[[IV]]]
+    %2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %5 = arith.addf %3, %2 : vector<16xf32>
+    %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %6 : tensor<128xf32>
+  }
+  %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) {
+  // CHECK-DAG:   [[SLICE1:%.*]] = vector.transfer_read [[IA]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT2:%.*]] = arith.addf [[SLICE1]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT1:%.*]] = vector.transfer_write [[OUT2]], [[IA]][[[IV]]]
+    %dup2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+  // NB: the dominance check used to fail on the following line,
+  // however the defining op for the value of %arg3 occurs above the source loop and hence is safe
+  // and %arg4 is a block argument of the scope of the loops and hence is safe
+    %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
+    %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+  // CHECK: scf.yield [[WRT0]], [[WRT1]] : {{.*}}
+    scf.yield %dup6 : tensor<128xf32>
+  }
+  return %1, %dup1 : tensor<128xf32>, tensor<128xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %for:2 = transform.split_handle %0 :  (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %for#1 into %for#0 : (!transform.any_op,!transform.any_op) ->  !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK: func.func @matmul_fuse_1st_forall_into_2nd([[A1:%.*]]: {{.*}}, [[A2:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @matmul_fuse_1st_forall_into_2nd(%A1 : tensor<128x128xf32>, %A2 : tensor<128x128xf32>, %B : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
   %zero = arith.constant 0.0 : f32
   %out_alloc = tensor.empty() : tensor<128x128xf32>
   %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
 
   // CHECK: scf.forall ([[I:%.*]]) in (4) shared_outs([[S1:%.*]] = [[IN1:%.*]], [[S2:%.*]] = [[IN2:%.*]]) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
   // CHECK:   [[T:%.*]] = affine.apply
+  // CHECK:   tensor.extract_slice [[A2]][[[T]], 0] [32, 128] [1, 1]
   // CHECK:   tensor.extract_slice [[S1]][[[T]], 0] [32, 128] [1, 1]
   // CHECK:   [[OUT1:%.*]] = linalg.matmul
+  // CHECK:   tensor.extract_slice [[A1]][[[T]], 0] [32, 128] [1, 1]
   // CHECK:   tensor.extract_slice [[S2]][[[T]], 0] [32, 128] [1, 1]
   // CHECK:   [[OUT2:%.*]] = linalg.matmul
   // CHECK:   scf.forall.in_parallel {
@@ -16,12 +115,11 @@ func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tenso
   // CHECK:     tensor.parallel_insert_slice [[OUT2]] into [[S2]][[[T]], 0] [32, 128] [1, 1]
   // CHECK:   }
   // CHECK: }
-  %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
-  %out2 = linalg.matmul ins(%A, %B2 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out1 = linalg.matmul ins(%A1, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out2 = linalg.matmul ins(%A2, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
 
   func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
 }
-
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) {
     %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
@@ -31,25 +129,37 @@ module attributes {transform.with_named_sequence} {
     %tiled_mm1, %loop1 = transform.structured.tile_using_forall %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
     %tiled_mm2, %loop2 = transform.structured.tile_using_forall %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 
-    %fused_loop = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    %fused_loop = transform.loop.fuse_sibling %loop2 into %loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
 
 // -----
 
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+// CHECK: func.func @matmul_fuse_2nd_forall_into_1st([[A1:%.*]]: {{.*}}, [[A2:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @matmul_fuse_2nd_forall_into_1st(%A1 : tensor<128x128xf32>, %A2 : tensor<128x128xf32>, %B : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
   %zero = arith.constant 0.0 : f32
   %out_alloc = tensor.empty() : tensor<128x128xf32>
   %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
 
-  // expected-error @below {{user of results of target should be properly dominated by source}}
-  %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
-  %out2 = linalg.matmul ins(%A, %out1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+  // CHECK: scf.forall ([[I:%.*]]) in (4) shared_outs([[S1:%.*]] = [[IN1:%.*]], [[S2:%.*]] = [[IN2:%.*]]) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+  // CHECK:   [[T:%.*]] = affine.apply
+  // CHECK:   tensor.extract_slice [[A1]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:   tensor.extract_slice [[S1]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:   [[OUT1:%.*]] = linalg.matmul
+  // CHECK:   tensor.extract_slice [[A2]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:   tensor.extract_slice [[S2]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:   [[OUT2:%.*]] = linalg.matmul
+  // CHECK:   scf.forall.in_parallel {
+  // CHECK:     tensor.parallel_insert_slice [[OUT1]] into [[S1]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:     tensor.parallel_insert_slice [[OUT2]] into [[S2]][[[T]], 0] [32, 128] [1, 1]
+  // CHECK:   }
+  // CHECK: }
+  %out1 = linalg.matmul ins(%A1, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out2 = linalg.matmul ins(%A2, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
 
   func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
 }
-
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) {
     %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
@@ -66,18 +176,84 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+// CHECK-NOCLEANUP: func.func @fuse_no_iter_args([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @fuse_no_iter_args(%A: tensor<128xf32>, %B: tensor<128xf32>) {
+  // CHECK-NOCLEANUP: [[C0:%.*]] = arith.constant 0 : index
+  // CHECK-NOCLEANUP: [[C16:%.*]] = arith.constant 16 : index
+  // CHECK-NOCLEANUP: [[C128:%.*]] = arith.constant 128 : index
+  // CHECK-NOCLEANUP: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c128 = arith.constant 128 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK-NOCLEANUP: scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C16]] {{.*}}
+  scf.for %arg0 = %c0 to %c128 step %c16 {
+  // CHECK-NOCLEANUP:   [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]]
+    %2 = vector.transfer_read %A[%arg0], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    scf.yield
+  }
+  scf.for %arg0 = %c0 to %c128 step %c16 {
+  // CHECK-NOCLEANUP:   [[BSLICE:%.*]] = vector.transfer_read [[B]][[[IV]]], [[ZERO]]
+    %dup2 = vector.transfer_read %B[%arg0], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    scf.yield
+  }
+  return
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %for:2 = transform.split_handle %0 :  (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %for#0 into %for#1 : (!transform.any_op,!transform.any_op) ->  !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @source_for_uses_result_of_target_for_err(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c128 = arith.constant 128 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  // expected-error @below {{user of results of target should be properly dominated by source}}
+  %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %A) -> (tensor<128xf32>) {
+    %2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %5 = arith.addf %3, %2 : vector<16xf32>
+    %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %6 : tensor<128xf32>
+  }
+  %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %1) -> (tensor<128xf32>) {
+    %dup2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
+    %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %dup6 : tensor<128xf32>
+  }
+  return %1, %dup1 : tensor<128xf32>, tensor<128xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %for:2 = transform.split_handle %0 :  (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %for#0 into %for#1 : (!transform.any_op,!transform.any_op) ->  !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @source_forall_uses_result_of_target_forall_err(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
   %zero = arith.constant 0.0 : f32
   %out_alloc = tensor.empty() : tensor<128x128xf32>
   %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
 
+  // expected-error @below {{user of results of target should be properly dominated by source}}
   %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
-  // expected-error @below {{values used inside regions of target should be properly dominated by source}}
   %out2 = linalg.matmul ins(%A, %out1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
 
   func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
 }
-
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) {
     %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
@@ -87,25 +263,58 @@ module attributes {transform.with_named_sequence} {
     %tiled_mm1, %loop1 = transform.structured.tile_using_forall %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
     %tiled_mm2, %loop2 = transform.structured.tile_using_forall %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
 
-    %fused_loop = transform.loop.fuse_sibling %loop2 into %loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    %fused_loop = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
 
 // -----
 
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
-  %zero = arith.constant 0.0 : f32
-  %out_alloc = tensor.empty() : tensor<128x128xf32>
-  %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+func.func @target_for_region_uses_result_of_source_for_err(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c128 = arith.constant 128 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %A) -> (tensor<128xf32>) {  // first loop: adds a 16-element slice of %A to the matching slice of its iter_arg
+    %2 = vector.transfer_read %A[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %5 = arith.addf %3, %2 : vector<16xf32>
+    %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %6 : tensor<128xf32>
+  }
+  %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) {  // second loop: same bounds and step, iter_arg initialised with %B
+  // expected-error @below {{values used inside regions of target should be properly dominated by source}}
+    %dup2 = vector.transfer_read %1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>  // reads %1, the result of the first loop, from inside the second loop's body
+    %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
+    %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %dup6 : tensor<128xf32>
+  }
+  return %1, %dup1 : tensor<128xf32>, tensor<128xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op  // one handle to all scf.for ops matched under %arg0
+    %for:2 = transform.split_handle %0 :  (!transform.any_op) -> (!transform.any_op, !transform.any_op)  // split into two handles, one per matched loop
+    %fused = transform.loop.fuse_sibling %for#1 into %for#0 : (!transform.any_op,!transform.any_op) ->  !transform.any_op  // request fusion of %for#1 into %for#0 (presumably the second loop into the first; handle order not confirmed here)
+    transform.yield
+  }
+}
 
-  %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+// -----
+
+func.func @target_forall_depends_on_value_not_dominated_by_source_forall_err(%A1 : tensor<128x128xf32>, %A2 : tensor<128x128xf32>, %B : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+  %zero = arith.constant 0.0 : f32
+  %buf1_alloc = tensor.empty() : tensor<128x128xf32>
+  %buf1 = linalg.fill ins(%zero : f32) outs(%buf1_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out1 = linalg.matmul ins(%A1, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%buf1 : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out_alloc2 = tensor.empty() : tensor<128x128xf32>  // separate buffer for the second matmul, created after %out1
+  %buf2 = linalg.fill ins(%zero : f32) outs(%out_alloc2 : tensor<128x128xf32>) -> tensor<128x128xf32>  // %buf2 is defined after %out1 and feeds %out2 below
   // expected-error @below {{operands of target should be properly dominated by source}}
-  %out2 = linalg.matmul ins(%A, %B2 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out1 : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out2 = linalg.matmul ins(%A2, %B : tensor<128x128xf32>, tensor<128x128xf32>) outs(%buf2 : tensor<128x128xf32>) -> tensor<128x128xf32>
 
   func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
 }
-
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) {
     %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
@@ -118,4 +327,4 @@ module attributes {transform.with_named_sequence} {
     %fused_loop = transform.loop.fuse_sibling %loop2 into %loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
     transform.yield
   }
-}
+}
\ No newline at end of file



More information about the Mlir-commits mailing list