[Mlir-commits] [mlir] [SCF][transform] Add support for scf.for in LoopFuseSibling op (PR #81495)

Rolf Morel llvmlistbot at llvm.org
Mon Feb 12 08:30:57 PST 2024


https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/81495

>From 51d524a4a2a57fd2d9c68d284ab6f99eaf05f06f Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at huawei.com>
Date: Tue, 30 Jan 2024 00:14:57 +0800
Subject: [PATCH] [SCF][Transform] Add support for scf.for in LoopFuseSibling
 op

Adds support for fusing two scf.for loops occurring in the same block.
Implementation mirrors that of LoopFuseSibling's support for scf.forall,
including only rudimentary checks, like the target loop's operands being
dominated by the source loop.

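Fixes a bug in the dominance check: for values used inside the target loop but
defined above it, the check tested whether the using operations (inside the
target loop) dominated the source loop, rather than the operations defining
those values. Block-argument values, which have no defining operation, are
skipped.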

Adds tests for using LoopFuseSibling on scf.for loops, including one which
fails without the fix for the dominance check.
---
 .../SCF/TransformOps/SCFTransformOps.td       |  10 +-
 mlir/include/mlir/Dialect/SCF/Utils/Utils.h   |  10 +
 .../SCF/TransformOps/SCFTransformOps.cpp      |  43 +++-
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          |  66 ++++++
 .../SCF/transform-loop-fuse-sibling.mlir      | 189 +++++++++++++++++-
 5 files changed, 300 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index cef73689c072b8..89d32ebcc24b10 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -342,11 +342,13 @@ def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
     Fuses the `target` loop into the `source` loop assuming they are
     independent of each other. It is the responsibility of the user to ensure
     that the given two loops are independent of each other, this operation will
-    not performa any legality checks and will simply fuse the two given loops.
+    not perform any legality checks and will simply fuse the two given loops.
 
-    Currently, the only fusion supported is when both `target` and `source`
-    are `scf.forall` operations. For `scf.forall` fusion, the bounds and the
-    mapping must match, otherwise a silencable failure is produced.
+    Currently, fusion is only supported in case both `target` and `source` are
+    `scf.for` operations or both are `scf.forall` operations. For `scf.for`
+    fusion the lower bound, upper bound and step must be the same SSA values.
+    For `scf.forall` fusion the bounds and the mapping must match. Otherwise a
+    silenceable failure is produced.
 
     The input handles `target` and `source` must map to exactly one operation,
     a definite failure is produced otherwise.
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 9bdd6eb833876f..883d11bcc4df06 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -162,6 +162,16 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
                                                 scf::ForallOp source,
                                                 RewriterBase &rewriter);
 
+/// Given two scf.for loops, `target` and `source`, fuses `target` into
+/// `source`. Assumes that the loops are siblings, are independent of each
+/// other and have identical bounds and step (only `source`'s are used). The
+/// fused loop is created right after `source`; its iter_args and results are
+/// those of `source` followed by those of `target`. Uses of the original
+/// loops' results are redirected to the fused loop; both loops are erased.
+/// No legality checks are performed; the caller must ensure fusion is legal.
+scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
+                                          RewriterBase &rewriter);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index bc2fe5772af9d6..7056185aeb456d 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -441,8 +441,11 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
     bool failed = false;
     OpOperand *failedValue = nullptr;
     visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
-      if (!domInfo.properlyDominates(operand->getOwner(), source,
-                                     /*enclosingOpOk=*/false)) {
+      Operation *operandOp = operand->get().getDefiningOp();
+      if (operandOp && !domInfo.properlyDominates(operandOp, source,
+                                                  /*enclosingOpOk=*/false)) {
+        // `operand` is not a block argument and its defining op does not
+        // dominate `source`
         failed = true;
         failedValue = operand;
       }
@@ -476,15 +479,34 @@ static bool isForallWithIdenticalConfiguration(Operation *target,
          targetOp.getMapping() == sourceOp.getMapping();
 }
 
-/// Fuse `target` into `source` assuming they are siblings and indepndent.
-/// TODO: Add fusion for more operations. Currently, we handle only scf.forall.
+/// Returns true iff `target` and `source` are both scf.for ops whose lower
+/// bounds, upper bounds and steps are pairwise the same SSA values.
+static bool isForWithIdenticalConfiguration(Operation *target,
+                                            Operation *source) {
+  auto targetFor = dyn_cast<scf::ForOp>(target);
+  auto sourceFor = dyn_cast<scf::ForOp>(source);
+  return targetFor && sourceFor &&
+         targetFor.getLowerBound() == sourceFor.getLowerBound() &&
+         targetFor.getUpperBound() == sourceFor.getUpperBound() &&
+         targetFor.getStep() == sourceFor.getStep();
+}
+
+/// Fuse `target` into `source` assuming they are siblings and independent.
+/// TODO: Support fusion for operations besides scf.for and scf.forall.
 static Operation *fuseSiblings(Operation *target, Operation *source,
                                RewriterBase &rewriter) {
-  auto targetOp = dyn_cast<scf::ForallOp>(target);
-  auto sourceOp = dyn_cast<scf::ForallOp>(source);
-  if (!targetOp || !sourceOp)
-    return nullptr;
-  return fuseIndependentSiblingForallLoops(targetOp, sourceOp, rewriter);
+  if (auto targetForOp = dyn_cast<scf::ForOp>(target)) {
+    if (auto sourceForOp = dyn_cast<scf::ForOp>(source))
+      return fuseIndependentSiblingForLoops(targetForOp, sourceForOp,
+                                            rewriter);
+    return nullptr;
+  }
+  if (auto targetForallOp = dyn_cast<scf::ForallOp>(target)) {
+    if (auto sourceForallOp = dyn_cast<scf::ForallOp>(source))
+      return fuseIndependentSiblingForallLoops(targetForallOp, sourceForallOp,
+                                               rewriter);
+  }
+  return nullptr;
 }
 
 DiagnosedSilenceableFailure
@@ -511,7 +533,8 @@ transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
     return diag;
 
   // Check if the target can be fused into source.
-  if (!isForallWithIdenticalConfiguration(target, source)) {
+  if (!isForallWithIdenticalConfiguration(target, source) &&
+      !isForWithIdenticalConfiguration(target, source)) {
     return emitSilenceableFailure(target->getLoc())
            << "operations cannot be fused";
   }
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index cdd85ddeb93add..f5836edf5eeb59 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -970,3 +970,69 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
 
   return fusedLoop;
 }
+
+scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
+                                                scf::ForOp source,
+                                                RewriterBase &rewriter) {
+  unsigned numSourceOuts = source.getNumResults();
+  unsigned numTargetOuts = target.getNumResults();
+
+  // Create fused init_args, with source's init_args before target's, so that
+  // the leading results of the fused loop correspond to source's results.
+  SmallVector<Value> fusedInitArgs;
+  fusedInitArgs.reserve(numSourceOuts + numTargetOuts);
+  llvm::append_range(fusedInitArgs, source.getInitArgs());
+  llvm::append_range(fusedInitArgs, target.getInitArgs());
+
+  // Create a new scf.for op after the source loop. Note that the builder only
+  // adds an (argument-less) scf.yield terminator when init_args is empty.
+  rewriter.setInsertionPointAfter(source);
+  scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
+      source.getLoc(), source.getLowerBound(), source.getUpperBound(),
+      source.getStep(), fusedInitArgs);
+
+  // Map the induction variables and iter_args of both original loops to those
+  // of the fused loop. Both loops share the fused induction variable as their
+  // bounds and steps are assumed to be identical.
+  IRMapping mapping;
+  mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
+  mapping.map(source.getRegionIterArgs(),
+              fusedLoop.getRegionIterArgs().take_front(numSourceOuts));
+  mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
+  mapping.map(target.getRegionIterArgs(),
+              fusedLoop.getRegionIterArgs().take_back(numTargetOuts));
+
+  // Clone the bodies of source and then target, without their terminators
+  // (every scf.for body has one, even without iter_args). Inserting at the
+  // start keeps the cloned ops before a builder-created terminator, if any.
+  rewriter.setInsertionPointToStart(fusedLoop.getBody());
+  for (Operation &op : source.getBody()->without_terminator())
+    rewriter.clone(op, mapping);
+  for (Operation &op : target.getBody()->without_terminator())
+    rewriter.clone(op, mapping);
+
+  // Combine the values yielded by both loops into a single scf.yield. If
+  // neither loop yields anything, the builder-created terminator is kept.
+  SmallVector<Value> yieldResults;
+  for (Value operand : source.getBody()->getTerminator()->getOperands())
+    yieldResults.push_back(mapping.lookupOrDefault(operand));
+  for (Value operand : target.getBody()->getTerminator()->getOperands())
+    yieldResults.push_back(mapping.lookupOrDefault(operand));
+  if (!yieldResults.empty()) {
+    // init_args are non-empty, so the builder did not add a terminator.
+    rewriter.setInsertionPointToEnd(fusedLoop.getBody());
+    rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
+  }
+
+  // Replace all uses of the old loops with the fused loop's results.
+  rewriter.replaceAllUsesWith(
+      source.getResults(), fusedLoop.getResults().take_front(numSourceOuts));
+  rewriter.replaceAllUsesWith(
+      target.getResults(), fusedLoop.getResults().take_back(numTargetOuts));
+
+  // Erase the old loops.
+  rewriter.eraseOp(target);
+  rewriter.eraseOp(source);
+
+  return fusedLoop;
+}
diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
index faaa2db3aa57de..332caf9cdf0516 100644
--- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt %s -transform-interpreter --cse --canonicalize -split-input-file -verify-diagnostics | FileCheck %s
 
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+func.func @matmul_fuse_1st_forall_into_2nd(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
   %zero = arith.constant 0.0 : f32
   %out_alloc = tensor.empty() : tensor<128x128xf32>
   %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
@@ -38,7 +38,7 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+func.func @matmul_fuse_1st_forall_into_2nd_err(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
   %zero = arith.constant 0.0 : f32
   %out_alloc = tensor.empty() : tensor<128x128xf32>
   %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
@@ -66,7 +66,7 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+func.func @matmul_fuse_2nd_forall_into_1st_err(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
   %zero = arith.constant 0.0 : f32
   %out_alloc = tensor.empty() : tensor<128x128xf32>
   %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
@@ -94,7 +94,7 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+func.func @matmul_fuse_2nd_forall_into_1st_err(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
   %zero = arith.constant 0.0 : f32
   %out_alloc = tensor.empty() : tensor<128x128xf32>
   %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
@@ -119,3 +119,183 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK: func.func @matmul_fuse_2nd_for_into_1st([[A:%.*]]: {{.*}}, [[B1:%.*]]: {{.*}}, [[B2:%.*]]: {{.*}} {{.*}}
+func.func @matmul_fuse_2nd_for_into_1st(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+  %zero = arith.constant 0.0 : f32
+  %out_alloc = tensor.empty() : tensor<128x128xf32>
+  %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+  // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+  // CHECK-DAG: [[C32:%.*]] = arith.constant 32 : index
+  // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index
+  // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: [[EMPTY:%.*]] = tensor.empty() : tensor<128x128xf32>
+  // CHECK-DAG: [[BUF:%.*]] = linalg.fill ins([[ZERO]] : {{.*}}) outs([[EMPTY]] : {{.*}}) {{.*}}
+  // CHECK: [[RST:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C32]] iter_args([[IA0:%.*]] = [[BUF]], [[IA1:%.*]] = [[BUF]]) {{.*}}
+  // CHECK-DAG:   [[ASLICE:%.*]] = tensor.extract_slice [[A]][[[IV]], 0] [32, 128] [1, 1]
+  // CHECK-DAG:   [[SLICE0:%.*]] = tensor.extract_slice [[IA0]][[[IV]], 0] [32, 128] [1, 1]
+  // CHECK:       [[OUT1:%.*]] = linalg.matmul ins([[ASLICE]], [[B1]] : {{.*}}) outs([[SLICE0]]
+  // CHECK-NEXT:  [[INS0:%.*]] = tensor.insert_slice [[OUT1]] into [[IA0]][[[IV]], 0] [32, 128] [1, 1]
+  // CHECK-DAG:   [[SLICE1:%.*]] = tensor.extract_slice [[IA1]][[[IV]], 0] [32, 128] [1, 1]
+  // CHECK:       [[OUT2:%.*]] = linalg.matmul ins([[ASLICE]], [[B2]] : {{.*}}) outs([[SLICE1]]
+  // CHECK-NEXT:  [[INS1:%.*]] = tensor.insert_slice [[OUT2]] into [[IA1]][[[IV]], 0] [32, 128] [1, 1]
+  // CHECK: scf.yield [[INS0]], [[INS1]] : {{.*}}
+  %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out2 = linalg.matmul ins(%A, %B2 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+  // CHECK: return [[RST]]#0, [[RST]]#1 : {{.*}}
+  func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) {
+    %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
+
+    %mm1, %mm2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    %tiled_mm1, %loop1 = transform.structured.tile_using_for %mm1 [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %tiled_mm2, %loop2 = transform.structured.tile_using_for %mm2 [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    %func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
+    %cse_func = transform.apply_registered_pass "cse" to %func : (!transform.any_op) -> (!transform.any_op)
+    %for_loops = transform.structured.match ops{["scf.for"]} in %cse_func : (!transform.any_op) -> (!transform.any_op)
+    %for_loop1, %for_loop2 = transform.split_handle %for_loops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused_loop = transform.loop.fuse_sibling %for_loop2 into %for_loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK: func.func @matmul_fuse_1st_for_into_2nd([[A:%.*]]: {{.*}}, [[B1:%.*]]: {{.*}}, [[B2:%.*]]: {{.*}} {{.*}}
+func.func @matmul_fuse_1st_for_into_2nd(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+  %zero = arith.constant 0.0 : f32
+  %out_alloc = tensor.empty() : tensor<128x128xf32>
+  %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+  // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+  // CHECK-DAG: [[C32:%.*]] = arith.constant 32 : index
+  // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index
+  // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: [[EMPTY:%.*]] = tensor.empty() : tensor<128x128xf32>
+  // CHECK-DAG: [[BUF:%.*]] = linalg.fill ins([[ZERO]] : {{.*}}) outs([[EMPTY]] : {{.*}}) {{.*}}
+  // CHECK: [[RST:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C32]] iter_args([[IA0:%.*]] = [[BUF]], [[IA1:%.*]] = [[BUF]]) {{.*}}
+  // CHECK-DAG:   [[ASLICE:%.*]] = tensor.extract_slice [[A]][[[IV]], 0] [32, 128] [1, 1]
+  // CHECK-DAG:   [[SLICE0:%.*]] = tensor.extract_slice [[IA0]][[[IV]], 0] [32, 128] [1, 1]
+  // CHECK:       [[OUT1:%.*]] = linalg.matmul ins([[ASLICE]], [[B2]] : {{.*}}) outs([[SLICE0]]
+  // CHECK-NEXT:  [[INS0:%.*]] = tensor.insert_slice [[OUT1]] into [[IA0]][[[IV]], 0] [32, 128] [1, 1]
+  // CHECK-DAG:   [[SLICE1:%.*]] = tensor.extract_slice [[IA1]][[[IV]], 0] [32, 128] [1, 1]
+  // CHECK:       [[OUT2:%.*]] = linalg.matmul ins([[ASLICE]], [[B1]] : {{.*}}) outs([[SLICE1]]
+  // CHECK-NEXT:  [[INS1:%.*]] = tensor.insert_slice [[OUT2]] into [[IA1]][[[IV]], 0] [32, 128] [1, 1]
+  // CHECK: scf.yield [[INS0]], [[INS1]] : {{.*}}
+  %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %out2 = linalg.matmul ins(%A, %B2 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+  // CHECK: return [[RST]]#1, [[RST]]#0 : {{.*}}
+  func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) {
+    %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
+
+    %mm1, %mm2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    %tiled_mm1, %loop1 = transform.structured.tile_using_for %mm1 [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %tiled_mm2, %loop2 = transform.structured.tile_using_for %mm2 [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    %func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
+    %cse_func = transform.apply_registered_pass "cse" to %func : (!transform.any_op) -> (!transform.any_op)
+    %for_loops = transform.structured.match ops{["scf.for"]} in %cse_func : (!transform.any_op) -> (!transform.any_op)
+    %for_loop1, %for_loop2 = transform.split_handle %for_loops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused_loop = transform.loop.fuse_sibling %for_loop1 into %for_loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// transform.loop.fuse_sibling used to fail on the following due to a bug in the dominance check (it checked the ops using values defined above the target loop instead of the ops defining them)
+
+// CHECK: func.func @no_dominance_bug([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @no_dominance_bug(%arg1: tensor<128xf32>, %arg2: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
+  // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+  // CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index
+  // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index
+  // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c128 = arith.constant 128 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK: [[RST:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C16]] iter_args([[IB0:%.*]] = [[B]], [[IB1:%.*]] = [[B]]) {{.*}}
+  %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+  // CHECK-DAG:   [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]]
+  // CHECK-DAG:   [[SLICE0:%.*]] = vector.transfer_read [[IB0]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IB0]][[[IV]]]
+    %2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %5 = arith.addf %3, %2 : vector<16xf32>
+    %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %6 : tensor<128xf32>
+  }
+  %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+  // CHECK-DAG:   [[SLICE1:%.*]] = vector.transfer_read [[IB1]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT2:%.*]] = arith.addf [[SLICE1]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT1:%.*]] = vector.transfer_write [[OUT2]], [[IB1]][[[IV]]]
+    %dup2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
+    %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+  // CHECK: scf.yield [[WRT0]], [[WRT1]] : {{.*}}
+    scf.yield %dup6 : tensor<128xf32>
+  }
+  return %1, %dup1 : tensor<128xf32>, tensor<128xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %for:2 = transform.split_handle %0 :  (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %for#1 into %for#0 : (!transform.any_op,!transform.any_op) ->  !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+// The target loop's step %c32 is defined after the source loop, so fusion must fail the dominance check.
+func.func @dominance_check_violation(%arg1: tensor<128xf32>, %arg2: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c128 = arith.constant 128 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+    %2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %5 = arith.addf %3, %2 : vector<16xf32>
+    %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %6 : tensor<128xf32>
+  }
+  %c32 = arith.constant 32 : index
+  // expected-error @below {{operands of target should be properly dominated by source}}
+  %dup1 = scf.for %arg3 = %c0 to %c128 step %c32 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+    %dup2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32>
+    %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32>
+    %dup5 = arith.addf %dup3, %dup2 : vector<32xf32>
+    %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<32xf32>, tensor<128xf32>
+    scf.yield %dup6 : tensor<128xf32>
+  }
+  return %1, %dup1 : tensor<128xf32>, tensor<128xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %for:2 = transform.split_handle %0 :  (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %fused = transform.loop.fuse_sibling %for#1 into %for#0 : (!transform.any_op,!transform.any_op) ->  !transform.any_op
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list