[Mlir-commits] [mlir] [SCF][transform] Add support for scf.for in LoopFuseSibling op (PR #81495)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 12 08:20:51 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-scf
@llvm/pr-subscribers-mlir
Author: Rolf Morel (rolfmorel)
<details>
<summary>Changes</summary>
Adds support for fusing two scf.for loops occurring in the same block. The implementation mirrors LoopFuseSibling's existing support for scf.forall and, like it, performs only rudimentary checks, such as checking that the values the target loop uses from enclosing scopes are defined by ops that dominate the source loop.
Fixes a bug in the dominance check: it used to check whether the ops inside the target loop that *use* such values dominate the source loop, rather than whether the ops *defining* those values do.
Adds tests for using LoopFuseSibling on scf.for loops, including one which fails without the fix for the dominance check.
---
Patch is 23.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81495.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td (+6-4)
- (modified) mlir/include/mlir/Dialect/SCF/Utils/Utils.h (+10)
- (modified) mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp (+33-10)
- (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+66)
- (modified) mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir (+185-4)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index cef73689c072b8..89d32ebcc24b10 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -342,11 +342,13 @@ def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
Fuses the `target` loop into the `source` loop assuming they are
independent of each other. It is the responsibility of the user to ensure
that the given two loops are independent of each other, this operation will
- not performa any legality checks and will simply fuse the two given loops.
+ not perform any legality checks and will simply fuse the two given loops.
- Currently, the only fusion supported is when both `target` and `source`
- are `scf.forall` operations. For `scf.forall` fusion, the bounds and the
- mapping must match, otherwise a silencable failure is produced.
+ Currently, fusion is only supported when both `target` and `source` are
+ `scf.for` operations or both are `scf.forall` operations. For `scf.for`
+ fusion, the lower bound, upper bound and step must be the same SSA values.
+ For `scf.forall` fusion, the bounds and the mapping must match. Otherwise, a
+ silenceable failure is produced.
The input handles `target` and `source` must map to exactly one operation,
a definite failure is produced otherwise.
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 9bdd6eb833876f..883d11bcc4df06 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -162,6 +162,16 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
scf::ForallOp source,
RewriterBase &rewriter);
+/// Given two scf.for loops, `target` and `source`, fuses `target` into
+/// `source`. Assumes that the given loops are siblings and are independent of
+/// each other.
+///
+/// The fused loop is created right after `source` and reuses `source`'s lower
+/// bound, upper bound and step. Its iter_args (and hence its results) are
+/// those of `source` followed by those of `target`. All uses of the results of
+/// the two original loops are redirected to the corresponding fused-loop
+/// results, and both original loops are erased.
+///
+/// This function does not perform any legality checks and simply fuses the
+/// loops. The caller is responsible for ensuring that the loops are legal to
+/// fuse (in particular, that `target`'s bounds and step match `source`'s, as
+/// only the latter are used).
+scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
+ RewriterBase &rewriter);
+
} // namespace mlir
#endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index bc2fe5772af9d6..7056185aeb456d 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -441,8 +441,11 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
bool failed = false;
OpOperand *failedValue = nullptr;
visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
- if (!domInfo.properlyDominates(operand->getOwner(), source,
- /*enclosingOpOk=*/false)) {
+ Operation *operandOp = operand->get().getDefiningOp();
+ if (operandOp && !domInfo.properlyDominates(operandOp, source,
+ /*enclosingOpOk=*/false)) {
+ // `operand` is not a block argument and its defining op does not
+ // dominate `source`
failed = true;
failedValue = operand;
}
@@ -476,15 +479,34 @@ static bool isForallWithIdenticalConfiguration(Operation *target,
targetOp.getMapping() == sourceOp.getMapping();
}
-/// Fuse `target` into `source` assuming they are siblings and indepndent.
-/// TODO: Add fusion for more operations. Currently, we handle only scf.forall.
+/// Returns true if `target` and `source` are both scf.for ops whose lower
+/// bounds, upper bounds and steps are the very same SSA values (no constant
+/// folding or value-equivalence reasoning is attempted).
+static bool isForWithIdenticalConfiguration(Operation *target,
+                                            Operation *source) {
+  auto targetFor = dyn_cast<scf::ForOp>(target);
+  auto sourceFor = dyn_cast<scf::ForOp>(source);
+  if (targetFor && sourceFor)
+    return targetFor.getStep() == sourceFor.getStep() &&
+           targetFor.getLowerBound() == sourceFor.getLowerBound() &&
+           targetFor.getUpperBound() == sourceFor.getUpperBound();
+  // At least one of the two ops is not an scf.for.
+  return false;
+}
+
+/// Fuses `target` into `source`, assuming the two are independent siblings.
+/// Returns the fused op, or nullptr when the pair of op kinds is not
+/// supported (mixed kinds, or neither scf.for nor scf.forall).
+/// TODO: Support fusion for operations besides scf.for and scf.forall.
static Operation *fuseSiblings(Operation *target, Operation *source,
RewriterBase &rewriter) {
- auto targetOp = dyn_cast<scf::ForallOp>(target);
- auto sourceOp = dyn_cast<scf::ForallOp>(source);
- if (!targetOp || !sourceOp)
- return nullptr;
- return fuseIndependentSiblingForallLoops(targetOp, sourceOp, rewriter);
+  // Both loops are scf.for.
+  if (auto targetFor = dyn_cast<scf::ForOp>(target)) {
+    if (auto sourceFor = dyn_cast<scf::ForOp>(source))
+      return fuseIndependentSiblingForLoops(targetFor, sourceFor, rewriter);
+  }
+
+  // Both loops are scf.forall.
+  if (auto targetForall = dyn_cast<scf::ForallOp>(target)) {
+    if (auto sourceForall = dyn_cast<scf::ForallOp>(source))
+      return fuseIndependentSiblingForallLoops(targetForall, sourceForall,
+                                               rewriter);
+  }
+
+  // Mixed or unsupported op kinds.
+  return nullptr;
}
DiagnosedSilenceableFailure
@@ -511,7 +533,8 @@ transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
return diag;
// Check if the target can be fused into source.
- if (!isForallWithIdenticalConfiguration(target, source)) {
+ if (!isForallWithIdenticalConfiguration(target, source) &&
+ !isForWithIdenticalConfiguration(target, source)) {
return emitSilenceableFailure(target->getLoc())
<< "operations cannot be fused";
}
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index cdd85ddeb93add..f5836edf5eeb59 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -970,3 +970,69 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
return fusedLoop;
}
+
+scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
+                                                scf::ForOp source,
+                                                RewriterBase &rewriter) {
+  unsigned numSourceOuts = source.getNumResults();
+  unsigned numTargetOuts = target.getNumResults();
+
+  // Create fused init_args: source's init_args followed by target's.
+  SmallVector<Value> fusedInitArgs;
+  fusedInitArgs.reserve(numSourceOuts + numTargetOuts);
+  llvm::append_range(fusedInitArgs, source.getInitArgs());
+  llvm::append_range(fusedInitArgs, target.getInitArgs());
+
+  // Create a new scf.for op right after the source loop. scf::ForOp::build
+  // only adds an (operand-less) scf.yield terminator when `fusedInitArgs` is
+  // empty; otherwise the body is left without a terminator.
+  rewriter.setInsertionPointAfter(source);
+  scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
+      source.getLoc(), source.getLowerBound(), source.getUpperBound(),
+      source.getStep(), fusedInitArgs);
+
+  // Map the induction variables and region iter_args of both original loops
+  // onto those of the fused loop (source's iter_args first, then target's).
+  IRMapping mapping;
+  mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
+  mapping.map(source.getRegionIterArgs(),
+              fusedLoop.getRegionIterArgs().take_front(numSourceOuts));
+  mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
+  mapping.map(target.getRegionIterArgs(),
+              fusedLoop.getRegionIterArgs().take_back(numTargetOuts));
+
+  // Clone source's body and then target's body into the fused loop, leaving
+  // out their terminators: a loop without iter_args still ends in an empty
+  // scf.yield, which must not end up in the middle of the fused body.
+  // Inserting at the start of the body keeps the cloned ops in front of the
+  // builder-created terminator, if there is one.
+  rewriter.setInsertionPointToStart(fusedLoop.getBody());
+  for (Operation &op : source.getBody()->without_terminator())
+    rewriter.clone(op, mapping);
+  for (Operation &op : target.getBody()->without_terminator())
+    rewriter.clone(op, mapping);
+
+  // Build the combined scf.yield from the remapped operands of the original
+  // terminators. Operands defined above the loops are kept unchanged.
+  SmallVector<Value> yieldResults;
+  yieldResults.reserve(numSourceOuts + numTargetOuts);
+  for (Value operand : source.getBody()->getTerminator()->getOperands())
+    yieldResults.push_back(mapping.lookupOrDefault(operand));
+  for (Value operand : target.getBody()->getTerminator()->getOperands())
+    yieldResults.push_back(mapping.lookupOrDefault(operand));
+  // Without any iter_args the builder already created the (empty) terminator.
+  if (!yieldResults.empty())
+    rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
+
+  // Redirect uses of the old loops' results to the fused loop's results and
+  // erase the old loops (replaceOp also notifies rewriter listeners).
+  rewriter.replaceOp(source, fusedLoop.getResults().take_front(numSourceOuts));
+  rewriter.replaceOp(target, fusedLoop.getResults().take_back(numTargetOuts));
+
+  return fusedLoop;
+}
diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
index faaa2db3aa57de..332caf9cdf0516 100644
--- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s -transform-interpreter --cse --canonicalize -split-input-file -verify-diagnostics | FileCheck %s
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+func.func @matmul_fuse_1st_forall_into_2nd(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
%zero = arith.constant 0.0 : f32
%out_alloc = tensor.empty() : tensor<128x128xf32>
%out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
@@ -38,7 +38,7 @@ module attributes {transform.with_named_sequence} {
// -----
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+func.func @matmul_fuse_1st_forall_into_2nd_err(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
%zero = arith.constant 0.0 : f32
%out_alloc = tensor.empty() : tensor<128x128xf32>
%out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
@@ -66,7 +66,7 @@ module attributes {transform.with_named_sequence} {
// -----
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+func.func @matmul_fuse_2nd_forall_into_1st_err(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
%zero = arith.constant 0.0 : f32
%out_alloc = tensor.empty() : tensor<128x128xf32>
%out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
@@ -94,7 +94,7 @@ module attributes {transform.with_named_sequence} {
// -----
-func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+func.func @matmul_fuse_2nd_forall_into_1st_err(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
%zero = arith.constant 0.0 : f32
%out_alloc = tensor.empty() : tensor<128x128xf32>
%out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
@@ -119,3 +119,184 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK: func.func @matmul_fuse_2nd_for_into_1st([[A:%.*]]: {{.*}}, [[B1:%.*]]: {{.*}}, [[B2:%.*]]: {{.*}} {{.*}}
+func.func @matmul_fuse_2nd_for_into_1st(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+ %zero = arith.constant 0.0 : f32
+ %out_alloc = tensor.empty() : tensor<128x128xf32>
+ %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+ // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+ // CHECK-DAG: [[C32:%.*]] = arith.constant 32 : index
+ // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index
+ // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK-DAG: [[EMPTY:%.*]] = tensor.empty() : tensor<128x128xf32>
+ // CHECK-DAG: [[BUF:%.*]] = linalg.fill ins([[ZERO]] : {{.*}}) outs([[EMPTY]] : {{.*}}) {{.*}}
+ // CHECK: [[RST:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C32]] iter_args([[IA0:%.*]] = [[BUF]], [[IA1:%.*]] = [[BUF]]) {{.*}}
+ // CHECK-DAG: [[ASLICE:%.*]] = tensor.extract_slice [[A]][[[IV]], 0] [32, 128] [1, 1]
+ // CHECK-DAG: [[SLICE0:%.*]] = tensor.extract_slice [[IA0]][[[IV]], 0] [32, 128] [1, 1]
+ // CHECK: [[OUT1:%.*]] = linalg.matmul ins([[ASLICE]], [[B1]] : {{.*}}) outs([[SLICE0]]
+ // CHECK-NEXT: [[INS0:%.*]] = tensor.insert_slice [[OUT1]] into [[IA0]][[[IV]], 0] [32, 128] [1, 1]
+ // CHECK-DAG: [[SLICE1:%.*]] = tensor.extract_slice [[IA1]][[[IV]], 0] [32, 128] [1, 1]
+ // CHECK: [[OUT2:%.*]] = linalg.matmul ins([[ASLICE]], [[B2]] : {{.*}}) outs([[SLICE1]]
+ // CHECK-NEXT: [[INS1:%.*]] = tensor.insert_slice [[OUT2]] into [[IA1]][[[IV]], 0] [32, 128] [1, 1]
+ // CHECK: scf.yield [[INS0]], [[INS1]] : {{.*}}
+ %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+ %out2 = linalg.matmul ins(%A, %B2 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+ // CHECK: return [[RST]]#0, [[RST]]#1 : {{.*}}
+ func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) {
+ %matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
+
+ %mm1, %mm2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+ %tiled_mm1, %loop1 = transform.structured.tile_using_for %mm1 [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %tiled_mm2, %loop2 = transform.structured.tile_using_for %mm2 [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+ %func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
+ %cse_func = transform.apply_registered_pass "cse" to %func : (!transform.any_op) -> (!transform.any_op)
+ %for_loops = transform.structured.match ops{["scf.for"]} in %cse_func : (!transform.any_op) -> (!transform.any_op)
+ %for_loop1, %for_loop2 = transform.split_handle %for_loops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %fused_loop = transform.loop.fuse_sibling %for_loop2 into %for_loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK: func.func @matmul_fuse_1st_for_into_2nd([[A:%.*]]: {{.*}}, [[B1:%.*]]: {{.*}}, [[B2:%.*]]: {{.*}} {{.*}}
+func.func @matmul_fuse_1st_for_into_2nd(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+ %zero = arith.constant 0.0 : f32
+ %out_alloc = tensor.empty() : tensor<128x128xf32>
+ %out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+ // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+ // CHECK-DAG: [[C32:%.*]] = arith.constant 32 : index
+ // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index
+ // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK-DAG: [[EMPTY:%.*]] = tensor.empty() : tensor<128x128xf32>
+ // CHECK-DAG: [[BUF:%.*]] = linalg.fill ins([[ZERO]] : {{.*}}) outs([[EMPTY]] : {{.*}}) {{.*}}
+ // CHECK: [[RST:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C32]] iter_args([[IA0:%.*]] = [[BUF]], [[IA1:%.*]] = [[BUF]]) {{.*}}
+ // CHECK-DAG: [[ASLICE:%.*]] = tensor.extract_slice [[A]][[[IV]], 0] [32, 128] [1, 1]
+ // CHECK-DAG: [[SLICE0:%.*]] = tensor.extract_slice [[IA0]][[[IV]], 0] [32, 128] [1, 1]
+ // CHECK: [[OUT1:%.*]] = linalg.matmul ins([[ASLICE]], [[B2]] : {{.*}}) outs([[SLICE0]]
+ // CHECK-NEXT: [[INS0:%.*]] = tensor.insert_slice [[OUT1]] into [[IA0]][[[IV]], 0] [32, 128] [1, 1]
+ // CHECK-DAG: [[SLICE1:%.*]] = tensor.extract_slice [[IA1]][[[IV]], 0] [32, 128] [1, 1]
+ // CHECK: [[OUT2:%.*]] = linalg.matmul ins([[ASLICE]], [[B1]] : {{.*}}) outs([[SLICE1]]
+ // CHECK-NEXT: [[INS1:%.*]] = tensor.insert_slice [[OUT2]] into [[IA1]][[[IV]], 0] [32, 128] [1, 1]
+ // CHECK: scf.yield [[INS0]], [[INS1]] : {{.*}}
+ %out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+ %out2 = linalg.matmul ins(%A, %B2 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
+
+ // CHECK: return [[RST]]#1, [[RST]]#0 : {{.*}}
+ func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%variant_op : !transform.any_op {transform.readonly}) {
+ // Tile each of the two matmuls with its own scf.for loop.
+ %matmuls = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
+ %matmul1, %matmul2 = transform.split_handle %matmuls : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %tiled_matmul1, %tile_loop1 = transform.structured.tile_using_for %matmul1 [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %tiled_matmul2, %tile_loop2 = transform.structured.tile_using_for %matmul2 [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+ // CSE the function so that equal bound/step constants of the two loops
+ // become the same SSA values, which fuse_sibling requires for scf.for.
+ %f = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
+ %f_after_cse = transform.apply_registered_pass "cse" to %f : (!transform.any_op) -> (!transform.any_op)
+
+ // Fuse the first scf.for loop into the second one.
+ %loops = transform.structured.match ops{["scf.for"]} in %f_after_cse : (!transform.any_op) -> (!transform.any_op)
+ %first_loop, %second_loop = transform.split_handle %loops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %fused = transform.loop.fuse_sibling %first_loop into %second_loop : (!transform.any_op, !transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// transform.loop.fuse_sibling used to silently fail on the following due to a bug in the dominance check
+
+// CHECK: func.func @no_dominance_bug([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @no_dominance_bug(%arg1: tensor<128xf32>, %arg2: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
+ // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+ // CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index
+ // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index
+ // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c128 = arith.constant 128 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ // CHECK: [[RST:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C16]] iter_args([[IB0:%.*]] = [[B]], [[IB1:%.*]] = [[B]]) {{.*}}
+ %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+ // CHECK-DAG: [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]]
+ // CHECK-DAG: [[SLICE0:%.*]] = vector.transfer_read [[IB0]][[[IV]]], [[ZERO]]
+ // CHECK: [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]]
+ // CHECK-NEXT: [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IB0]][[[IV]]]
+ %2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+ %5 = arith.addf %3, %2 : vector<16xf32>
+ %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+ scf.yield %6 : tensor<128xf32>
+ }
+ %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+ // CHECK-DAG: [[SLICE1:%.*]] = vector.transfer_read [[IB1]][[[IV]]], [[ZERO]]
+ // CHECK...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/81495
More information about the Mlir-commits
mailing list