[Mlir-commits] [mlir] [SCF][Transform] Add support for scf.for in LoopFuseSibling op (PR #81495)

Rolf Morel llvmlistbot at llvm.org
Mon Mar 18 01:36:58 PDT 2024


https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/81495

>From dd30ea491309eb2454539b86d6fcfc4d23fff5d6 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at huawei.com>
Date: Tue, 30 Jan 2024 00:14:57 +0800
Subject: [PATCH] [SCF][Transform] Add support for scf.for in LoopFuseSibling
 op

Adds support for fusing two scf.for loops occurring in the same block.
Implementation mirrors that of LoopFuseSibling's support for scf.forall,
including only rudimentary checks, like the target loop's operands being
dominated by the source loop.

Fixes a bug in the dominance check whereby it was checked that the ops in the
target loop that use a value dominated the source loop, rather than the ops
that define those values.

Adds tests for using LoopFuseSibling on scf.for loops, including one which
fails without the fix for the dominance check.
---
 .../SCF/TransformOps/SCFTransformOps.td       |  12 +-
 mlir/include/mlir/Dialect/SCF/Utils/Utils.h   |  10 +
 .../SCF/TransformOps/SCFTransformOps.cpp      |  51 ++--
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          |  93 +++++---
 .../SCF/transform-loop-fuse-sibling.mlir      | 222 +++++++++++++++++-
 5 files changed, 337 insertions(+), 51 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index cef73689c072b8..a48bc898b8d478 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -333,7 +333,7 @@ def TakeAssumedBranchOp : Op<Transform_Dialect, "scf.take_assumed_branch", [
   }];
 }
 
-def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
+def LoopFuseSiblingOp : Op<Transform_Dialect, "loop.fuse_sibling",
   [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
    DeclareOpInterfaceMethods<TransformOpInterface>]> {
   let summary = "Fuse a loop into another loop, assuming the fusion is legal.";
@@ -342,11 +342,13 @@ def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
     Fuses the `target` loop into the `source` loop assuming they are
     independent of each other. It is the responsibility of the user to ensure
     that the given two loops are independent of each other, this operation will
-    not performa any legality checks and will simply fuse the two given loops.
+    not perform any legality checks and will simply fuse the two given loops.
 
-    Currently, the only fusion supported is when both `target` and `source`
-    are `scf.forall` operations. For `scf.forall` fusion, the bounds and the
-    mapping must match, otherwise a silencable failure is produced.
+    Currently, fusion is only supported in case both `target` and `source` are
+    `scf.for` operations or both are `scf.forall` operations. For `scf.for`
+    fusion the bounds and step size must match. For `scf.forall` fusion the
+    bounds and the mapping must match. Otherwise a silenceable failure is
+    produced.
 
     The input handles `target` and `source` must map to exactly one operation,
     a definite failure is produced otherwise.
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 9bdd6eb833876f..883d11bcc4df06 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -162,6 +162,16 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
                                                  scf::ForallOp source,
                                                  RewriterBase &rewriter);
 
+/// Given two scf.for loops, `target` and `source`, fuses `target` into
+/// `source`, assuming the loops are independent siblings. The fused loop is
+/// created right after `source`; its iter_args and results are those of
+/// `target` followed by those of `source`, and it replaces both loops.
+///
+/// No legality checks are performed: the caller is responsible for ensuring
+/// the loops are legal to fuse (e.g. that their bounds and steps match).
+scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
+                                          RewriterBase &rewriter);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index bc2fe5772af9d6..e8fcd17b6f0814 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -384,7 +384,7 @@ void transform::TakeAssumedBranchOp::getEffects(
 }
 
 //===----------------------------------------------------------------------===//
-// LoopFuseSibling
+// LoopFuseSiblingOp
 //===----------------------------------------------------------------------===//
 
 /// Check if `target` and `source` are siblings, in the context that `target`
@@ -441,8 +441,11 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
     bool failed = false;
     OpOperand *failedValue = nullptr;
     visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
-      if (!domInfo.properlyDominates(operand->getOwner(), source,
-                                     /*enclosingOpOk=*/false)) {
+      Operation *operandOp = operand->get().getDefiningOp();
+      if (operandOp && !domInfo.properlyDominates(operandOp, source,
+                                                  /*enclosingOpOk=*/false)) {
+        // The used value is produced by an op (it is not a block argument) and
+        // that defining op does not properly dominate `source`.
         failed = true;
         failedValue = operand;
       }
@@ -476,21 +479,40 @@ static bool isForallWithIdenticalConfiguration(Operation *target,
          targetOp.getMapping() == sourceOp.getMapping();
 }
 
-/// Fuse `target` into `source` assuming they are siblings and indepndent.
-/// TODO: Add fusion for more operations. Currently, we handle only scf.forall.
+static bool isForWithIdenticalConfiguration(Operation *target,
+                                            Operation *source) {
+  auto targetOp = dyn_cast<scf::ForOp>(target);
+  auto sourceOp = dyn_cast<scf::ForOp>(source);
+  if (!targetOp || !sourceOp)
+    return false;
+
+  return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
+         targetOp.getUpperBound() == sourceOp.getUpperBound() &&
+         targetOp.getStep() == sourceOp.getStep();
+}
+
+/// Fuse `target` into `source`, assuming they are independent siblings of the
+/// same kind (nullptr otherwise). TODO: support more than scf.for/scf.forall.
 static Operation *fuseSiblings(Operation *target, Operation *source,
                                RewriterBase &rewriter) {
-  auto targetOp = dyn_cast<scf::ForallOp>(target);
-  auto sourceOp = dyn_cast<scf::ForallOp>(source);
-  if (!targetOp || !sourceOp)
-    return nullptr;
-  return fuseIndependentSiblingForallLoops(targetOp, sourceOp, rewriter);
+  auto targetForOp = dyn_cast<scf::ForOp>(target);
+  auto sourceForOp = dyn_cast<scf::ForOp>(source);
+  if (targetForOp && sourceForOp)
+    return fuseIndependentSiblingForLoops(targetForOp, sourceForOp, rewriter);
+
+  auto targetForallOp = dyn_cast<scf::ForallOp>(target);
+  auto sourceForallOp = dyn_cast<scf::ForallOp>(source);
+  if (targetForallOp && sourceForallOp)
+    return fuseIndependentSiblingForallLoops(targetForallOp, sourceForallOp,
+                                             rewriter);
+
+  return nullptr;
 }
 
 DiagnosedSilenceableFailure
-transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
-                                  transform::TransformResults &results,
-                                  transform::TransformState &state) {
+transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
+                                    transform::TransformResults &results,
+                                    transform::TransformState &state) {
   auto targetOps = state.getPayloadOps(getTarget());
   auto sourceOps = state.getPayloadOps(getSource());
 
@@ -511,7 +533,8 @@ transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
     return diag;
 
   // Check if the target can be fused into source.
-  if (!isForallWithIdenticalConfiguration(target, source)) {
+  if (!isForallWithIdenticalConfiguration(target, source) &&
+      !isForWithIdenticalConfiguration(target, source)) {
     return emitSilenceableFailure(target->getLoc())
            << "operations cannot be fused";
   }
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index cdd85ddeb93add..c4f78d68f1c90d 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -912,39 +912,34 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
   unsigned numTargetOuts = target.getNumResults();
   unsigned numSourceOuts = source.getNumResults();
 
-  OperandRange targetOuts = target.getOutputs();
-  OperandRange sourceOuts = source.getOutputs();
-
   // Create fused shared_outs.
   SmallVector<Value> fusedOuts;
-  fusedOuts.reserve(numTargetOuts + numSourceOuts);
-  fusedOuts.append(targetOuts.begin(), targetOuts.end());
-  fusedOuts.append(sourceOuts.begin(), sourceOuts.end());
+  llvm::append_range(fusedOuts, target.getOutputs());
+  llvm::append_range(fusedOuts, source.getOutputs());
 
-  // Create a new scf::forall op after the source loop.
+  // Create a new scf.forall op after the source loop.
   rewriter.setInsertionPointAfter(source);
   scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
       source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
       source.getMixedStep(), fusedOuts, source.getMapping());
 
   // Map control operands.
-  IRMapping fusedMapping;
-  fusedMapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
-  fusedMapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
+  IRMapping mapping;
+  mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
+  mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
 
   // Map shared outs.
-  fusedMapping.map(target.getRegionIterArgs(),
-                   fusedLoop.getRegionIterArgs().slice(0, numTargetOuts));
-  fusedMapping.map(
-      source.getRegionIterArgs(),
-      fusedLoop.getRegionIterArgs().slice(numTargetOuts, numSourceOuts));
+  mapping.map(target.getRegionIterArgs(),
+              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
+  mapping.map(source.getRegionIterArgs(),
+              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
 
   // Append everything except the terminator into the fused operation.
   rewriter.setInsertionPointToStart(fusedLoop.getBody());
   for (Operation &op : target.getBody()->without_terminator())
-    rewriter.clone(op, fusedMapping);
+    rewriter.clone(op, mapping);
   for (Operation &op : source.getBody()->without_terminator())
-    rewriter.clone(op, fusedMapping);
+    rewriter.clone(op, mapping);
 
   // Fuse the old terminator in_parallel ops into the new one.
   scf::InParallelOp targetTerm = target.getTerminator();
@@ -953,20 +948,65 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
 
   rewriter.setInsertionPointToStart(fusedTerm.getBody());
   for (Operation &op : targetTerm.getYieldingOps())
-    rewriter.clone(op, fusedMapping);
+    rewriter.clone(op, mapping);
   for (Operation &op : sourceTerm.getYieldingOps())
-    rewriter.clone(op, fusedMapping);
+    rewriter.clone(op, mapping);
 
   // Replace all uses of the old loops with the fused loop.
-  rewriter.replaceAllUsesWith(target.getResults(),
-                              fusedLoop.getResults().slice(0, numTargetOuts));
-  rewriter.replaceAllUsesWith(
-      source.getResults(),
-      fusedLoop.getResults().slice(numTargetOuts, numSourceOuts));
-
-  // Erase the old loops.
-  rewriter.eraseOp(target);
-  rewriter.eraseOp(source);
+  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
+  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
+
+  return fusedLoop;
+}
+
+scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
+                                                scf::ForOp source,
+                                                RewriterBase &rewriter) {
+  unsigned numTargetOuts = target.getNumResults();
+  unsigned numSourceOuts = source.getNumResults();
+
+  // Create fused init_args, with target's init_args before source's init_args.
+  SmallVector<Value> fusedInitArgs;
+  llvm::append_range(fusedInitArgs, target.getInitArgs());
+  llvm::append_range(fusedInitArgs, source.getInitArgs());
+
+  // Create a new scf.for op after the source loop.
+  rewriter.setInsertionPointAfter(source);
+  scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
+      source.getLoc(), source.getLowerBound(), source.getUpperBound(),
+      source.getStep(), fusedInitArgs);
+
+  // Map original induction variables and operands to those of the fused loop.
+  IRMapping mapping;
+  mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
+  mapping.map(target.getRegionIterArgs(),
+              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
+
+  mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
+  mapping.map(source.getRegionIterArgs(),
+              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
+
+  // Merge target's body into the new (fused) for loop and then source's body.
+  rewriter.setInsertionPointToStart(fusedLoop.getBody());
+  for (Operation &op : target.getBody()->without_terminator())
+    rewriter.clone(op, mapping);
+  for (Operation &op : source.getBody()->without_terminator())
+    rewriter.clone(op, mapping);
+
+  // Build fused yield results by appropriately mapping original yield operands.
+  SmallVector<Value> yieldResults;
+  for (Value operand : target.getBody()->getTerminator()->getOperands())
+    yieldResults.push_back(mapping.lookupOrDefault(operand));
+  for (Value operand : source.getBody()->getTerminator()->getOperands())
+    yieldResults.push_back(mapping.lookupOrDefault(operand));
+  // Without iter_args, the scf.for builder has already created an empty
+  // scf.yield terminator; only add one when there are values to yield.
+  if (!yieldResults.empty())
+    rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
+
+  // Replace old loops by substituting their uses by results of the fused loop.
+  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
+  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
 
   return fusedLoop;
 }
diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
index faaa2db3aa57de..67bc79bd4bbf81 100644
--- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt %s -transform-interpreter --cse --canonicalize -split-input-file -verify-diagnostics | FileCheck %s
 
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+func.func @matmul_fuse_1st_forall_into_2nd(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
   %zero = arith.constant 0.0 : f32
   %out_alloc = tensor.empty() : tensor<128x128xf32>
   %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
@@ -38,7 +38,7 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+func.func @matmul_fuse_1st_forall_into_2nd_err(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
   %zero = arith.constant 0.0 : f32
   %out_alloc = tensor.empty() : tensor<128x128xf32>
   %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
@@ -66,7 +66,41 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+func.func @fuse_2nd_for_into_1st_step_mismatch_err(%arg1: tensor<128xf32>, %arg2: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c128 = arith.constant 128 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+    %2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %5 = arith.addf %3, %2 : vector<16xf32>
+    %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %6 : tensor<128xf32>
+  }
+  // expected-error @below {{operations cannot be fused}}
+  %dup1 = scf.for %arg3 = %c0 to %c128 step %c32 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+    %dup2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32>
+    %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32>
+    %dup5 = arith.addf %dup3, %dup2 : vector<32xf32>
+    %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<32xf32>, tensor<128xf32>
+    scf.yield %dup6 : tensor<128xf32>
+  }
+  return %1, %dup1 : tensor<128xf32>, tensor<128xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) {
+    %matched = transform.structured.match ops{["scf.for"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
+    %first_loop, %second_loop = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused_loop = transform.loop.fuse_sibling %second_loop into %first_loop : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+func.func @matmul_fuse_2nd_forall_into_1st_err(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
   %zero = arith.constant 0.0 : f32
   %out_alloc = tensor.empty() : tensor<128x128xf32>
   %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
@@ -94,7 +128,7 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+func.func @matmul_fuse_2nd_forall_into_1st_err(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
   %zero = arith.constant 0.0 : f32
   %out_alloc = tensor.empty() : tensor<128x128xf32>
   %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
@@ -119,3 +153,183 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK: func.func @matmul_fuse_1st_for_into_2nd([[A:%.*]]: {{.*}}, [[B1:%.*]]: {{.*}}, [[B2:%.*]]: {{.*}} {{.*}}
+func.func @matmul_fuse_1st_for_into_2nd(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+  %zero = arith.constant 0.0 : f32
+  %out_alloc = tensor.empty() : tensor<128x128xf32>
+  %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+  // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+  // CHECK-DAG: [[C32:%.*]] = arith.constant 32 : index
+  // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index
+  // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: [[EMPTY:%.*]] = tensor.empty() : tensor<128x128xf32>
+  // CHECK-DAG: [[BUF:%.*]] = linalg.fill ins([[ZERO]] : {{.*}}) outs([[EMPTY]] : {{.*}}) {{.*}}
+  // CHECK: [[RST:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C32]] iter_args([[IA0:%.*]] = [[BUF]], [[IA1:%.*]] = [[BUF]]) {{.*}}
+  // CHECK-DAG:   [[ASLICE:%.*]] = tensor.extract_slice [[A]][[[IV]], 0] [32, 128] [1, 1]
+  // CHECK-DAG:   [[SLICE0:%.*]] = tensor.extract_slice [[IA0]][[[IV]], 0] [32, 128] [1, 1]
+  // CHECK:       [[OUT1:%.*]] = linalg.matmul ins([[ASLICE]], [[B1]] : {{.*}}) outs([[SLICE0]]
+  // CHECK-NEXT:  [[INS0:%.*]] = tensor.insert_slice [[OUT1]] into [[IA0]][[[IV]], 0] [32, 128] [1, 1]
+  // CHECK-DAG:   [[SLICE1:%.*]] = tensor.extract_slice [[IA1]][[[IV]], 0] [32, 128] [1, 1]
+  // CHECK:       [[OUT2:%.*]] = linalg.matmul ins([[ASLICE]], [[B2]] : {{.*}}) outs([[SLICE1]]
+  // CHECK-NEXT:  [[INS1:%.*]] = tensor.insert_slice [[OUT2]] into [[IA1]][[[IV]], 0] [32, 128] [1, 1]
+  // CHECK: scf.yield [[INS0]], [[INS1]] : {{.*}}
+  %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out2 = linalg.matmul ins(%A, %B2 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+  // CHECK: return [[RST]]#0, [[RST]]#1 : {{.*}}
+  func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) {
+    %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
+
+    %mm1, %mm2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    %tiled_mm1, %loop1 = transform.structured.tile_using_for %mm1 [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %tiled_mm2, %loop2 = transform.structured.tile_using_for %mm2 [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    %func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
+    %cse_func = transform.apply_registered_pass "cse" to %func : (!transform.any_op) -> (!transform.any_op)
+    %for_loops = transform.structured.match ops{["scf.for"]} in %cse_func : (!transform.any_op) -> (!transform.any_op)
+    %for_loop1, %for_loop2 = transform.split_handle %for_loops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused_loop = transform.loop.fuse_sibling %for_loop1 into %for_loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK: func.func @matmul_fuse_2nd_for_into_1st([[A:%.*]]: {{.*}}, [[B1:%.*]]: {{.*}}, [[B2:%.*]]: {{.*}} {{.*}}
+func.func @matmul_fuse_2nd_for_into_1st(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+  %zero = arith.constant 0.0 : f32
+  %out_alloc = tensor.empty() : tensor<128x128xf32>
+  %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+  // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+  // CHECK-DAG: [[C32:%.*]] = arith.constant 32 : index
+  // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index
+  // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: [[EMPTY:%.*]] = tensor.empty() : tensor<128x128xf32>
+  // CHECK-DAG: [[BUF:%.*]] = linalg.fill ins([[ZERO]] : {{.*}}) outs([[EMPTY]] : {{.*}}) {{.*}}
+  // CHECK: [[RST:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C32]] iter_args([[IA0:%.*]] = [[BUF]], [[IA1:%.*]] = [[BUF]]) {{.*}}
+  // CHECK-DAG:   [[ASLICE:%.*]] = tensor.extract_slice [[A]][[[IV]], 0] [32, 128] [1, 1]
+  // CHECK-DAG:   [[SLICE0:%.*]] = tensor.extract_slice [[IA0]][[[IV]], 0] [32, 128] [1, 1]
+  // CHECK:       [[OUT1:%.*]] = linalg.matmul ins([[ASLICE]], [[B2]] : {{.*}}) outs([[SLICE0]]
+  // CHECK-NEXT:  [[INS0:%.*]] = tensor.insert_slice [[OUT1]] into [[IA0]][[[IV]], 0] [32, 128] [1, 1]
+  // CHECK-DAG:   [[SLICE1:%.*]] = tensor.extract_slice [[IA1]][[[IV]], 0] [32, 128] [1, 1]
+  // CHECK:       [[OUT2:%.*]] = linalg.matmul ins([[ASLICE]], [[B1]] : {{.*}}) outs([[SLICE1]]
+  // CHECK-NEXT:  [[INS1:%.*]] = tensor.insert_slice [[OUT2]] into [[IA1]][[[IV]], 0] [32, 128] [1, 1]
+  // CHECK: scf.yield [[INS0]], [[INS1]] : {{.*}}
+  %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out2 = linalg.matmul ins(%A, %B2 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+  // CHECK: return [[RST]]#1, [[RST]]#0 : {{.*}}
+  func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) {
+    %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
+
+    %mm1, %mm2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    %tiled_mm1, %loop1 = transform.structured.tile_using_for %mm1 [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %tiled_mm2, %loop2 = transform.structured.tile_using_for %mm2 [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    %func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
+    %cse_func = transform.apply_registered_pass "cse" to %func : (!transform.any_op) -> (!transform.any_op)
+    %for_loops = transform.structured.match ops{["scf.for"]} in %cse_func : (!transform.any_op) -> (!transform.any_op)
+    %for_loop1, %for_loop2 = transform.split_handle %for_loops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused_loop = transform.loop.fuse_sibling %for_loop2 into %for_loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// transform.loop.fuse_sibling used to reject the following due to a bug in the dominance check: it required the ops using values defined above the target loop to dominate the source loop, instead of the ops defining those values.
+
+// CHECK: func.func @no_dominance_bug([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @no_dominance_bug(%arg1: tensor<128xf32>, %arg2: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
+  // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+  // CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index
+  // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index
+  // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c128 = arith.constant 128 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK: [[RST:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C16]] iter_args([[IB0:%.*]] = [[B]], [[IB1:%.*]] = [[B]]) {{.*}}
+  %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+  // CHECK-DAG:   [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]]
+  // CHECK-DAG:   [[SLICE0:%.*]] = vector.transfer_read [[IB0]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IB0]][[[IV]]]
+    %2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %5 = arith.addf %3, %2 : vector<16xf32>
+    %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %6 : tensor<128xf32>
+  }
+  %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+  // CHECK-DAG:   [[SLICE1:%.*]] = vector.transfer_read [[IB1]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT2:%.*]] = arith.addf [[SLICE1]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT1:%.*]] = vector.transfer_write [[OUT2]], [[IB1]][[[IV]]]
+    %dup2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
+    %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+  // CHECK: scf.yield [[WRT0]], [[WRT1]] : {{.*}}
+    scf.yield %dup6 : tensor<128xf32>
+  }
+  return %1, %dup1 : tensor<128xf32>, tensor<128xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %for:2 = transform.split_handle %0 :  (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %for#1 into %for#0 : (!transform.any_op,!transform.any_op) ->  !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @dominance_check_violation(%arg1: tensor<128xf32>, %arg2: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c128 = arith.constant 128 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+    %2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %5 = arith.addf %3, %2 : vector<16xf32>
+    %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %6 : tensor<128xf32>
+  }
+  %c32 = arith.constant 32 : index // Defined after the first loop, so fusing the second loop into the first would use it before it is defined.
+  // expected-error @below {{operands of target should be properly dominated by source}}
+  %dup1 = scf.for %arg3 = %c0 to %c128 step %c32 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+    %dup2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32>
+    %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32>
+    %dup5 = arith.addf %dup3, %dup2 : vector<32xf32>
+    %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<32xf32>, tensor<128xf32>
+    scf.yield %dup6 : tensor<128xf32>
+  }
+  return %1, %dup1 : tensor<128xf32>, tensor<128xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %for:2 = transform.split_handle %0 :  (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %for#1 into %for#0 : (!transform.any_op,!transform.any_op) ->  !transform.any_op
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list