[Mlir-commits] [mlir] 70ded9b - [mlir] Add support for multiple uses in transform.structured.fuse_into_containing_op
Harsh Menon
llvmlistbot at llvm.org
Wed May 24 08:43:45 PDT 2023
Author: Harsh Menon
Date: 2023-05-24T08:36:41-07:00
New Revision: 70ded9bc86a974aec3545c9fce32059a4179b5cf
URL: https://github.com/llvm/llvm-project/commit/70ded9bc86a974aec3545c9fce32059a4179b5cf
DIFF: https://github.com/llvm/llvm-project/commit/70ded9bc86a974aec3545c9fce32059a4179b5cf.diff
LOG: [mlir] Add support for multiple uses in transform.structured.fuse_into_containing_op
In the tile and fuse of the first extract use, we add support
for scenarios where the results of the tiled op have uses
that are dominated by the scf.forall. Specifically, we replace
the scf.forall with a new scf.forall that has an additional
shared_out and add the appropriate parallel insert slice op.
Differential Revision: https://reviews.llvm.org/D151275
Added:
Modified:
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index c1a24a4e04add..4f476d1053827 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -347,10 +347,82 @@ void transform::FuseIntoContainingOp::build(OpBuilder &builder,
result.addTypes(transform::AnyOpType::get(builder.getContext()));
}
+/// Build a replacement for the containing scf.forall op with one extra output,
+/// for producerOp users it dominates; nullptr if none or producer isn't generic.
+static Operation *replaceForAllWithNewSignature(
+ RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
+ Operation *containingOp, TilingResult &tileAndFuseResult,
+ int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
+ SmallVector<OpFoldResult> &sizes) {
+
+ // Collect users of the result, other than the containing op, that it dominates
+ SetVector<Operation *> dominatedUsers;
+ DominanceInfo domInfo(containingOp);
+ for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
+ if ((user != containingOp) && (domInfo.dominates(containingOp, user))) {
+ dominatedUsers.insert(user);
+ }
+ }
+ if (dominatedUsers.size() == 0)
+ return nullptr;
+
+ // Set the insertion point at the existing scf.forall op (restored by the guard)
+ auto forallOp = cast<scf::ForallOp>(containingOp);
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(forallOp);
+
+ // New outputs: the existing ones plus the producer's output for resultNumber
+ Location loc = forallOp.getLoc();
+ auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
+ if (!genericOp)
+ return nullptr;
+ SmallVector<Value> outputs = genericOp.getOutputs();
+ SmallVector<Value> newOuts(forallOp.getOutputs());
+ newOuts.push_back(outputs[resultNumber]);
+
+ // Create new scf.forall op and move the original op's body into it
+ auto newforallOp = rewriter.create<scf::ForallOp>(
+ loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+ forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+ rewriter.eraseBlock(newforallOp.getBody());
+ newforallOp.getRegion().takeBody(forallOp.getRegion());
+
+ // Add additional block argument for new value being returned
+ newforallOp.getBody()->addArgument(newOuts.back().getType(),
+ newOuts.back().getLoc());
+
+ // Fix terminator: insert the tiled value into the new output before the first yield
+ scf::InParallelOp terminatorOp = newforallOp.getTerminator();
+ SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
+ terminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
+ Operation *firstYieldOp = yieldingOps.front();
+ rewriter.setInsertionPoint(firstYieldOp);
+ Value src = tileAndFuseResult.tiledValues[0];
+ Value dst = newforallOp.getOutputBlockArguments().back();
+ SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
+ rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src,
+ dst, offsets, sizes, strides);
+
+ for (auto result : llvm::enumerate(forallOp.getResults())) {
+ rewriter.replaceAllUsesWith(result.value(),
+ newforallOp->getResult(result.index()));
+ }
+ rewriter.replaceUsesWithIf(producerOp->getResult(resultNumber),
+ newforallOp->getResults().back(),
+ [&](OpOperand &use) {
+ Operation *user = use.getOwner();
+ return dominatedUsers.contains(user);
+ });
+ return newforallOp;
+}
+
/// Find the first "extract" user of `producerOp` and tile it right before its
/// use. The tiled op is fused under the `containingOp`.
/// Return this fused op on success or nullptr if anything fails.
-static SmallVector<Operation *>
+/// If tiled op has uses that are dominated by `containingOp`, return
+/// a new `containingOp` with results of the fused op appended to
+/// results of the `containingOp` or nullptr if there are no dominated uses.
+static std::tuple<SmallVector<Operation *>, Operation *>
tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
Operation *producerOp, Operation *containingOp) {
LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
@@ -386,10 +458,13 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
+ SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();
+
FailureOr<TilingResult> tileAndFuseResult =
- tileableProducer.generateResultTileValue(rewriter, resultNumber,
- sliceOpToTile.getMixedOffsets(),
- sliceOpToTile.getMixedSizes());
+ tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
+ sizes);
+
if (failed(tileAndFuseResult)) {
diag.attachNote(tileableProducer->getLoc())
<< "failed to tile producer op: " << *tileableProducer;
@@ -408,7 +483,13 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
assert(succeeded(maybeRankReduced) && "unexpected shape");
rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
- return tileAndFuseResult->tiledOps;
+
+ // Add new outputs to containing op, if required
+ Operation *newContainingOp = replaceForAllWithNewSignature(
+ rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
+ resultNumber, offsets, sizes);
+
+ return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
}
/// First, find the first "scf::ForallOp" user of `producerOp` and ensure
@@ -635,11 +716,15 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
// cases, we can tile/clone once and reuse the value for each use.
// Futhermore, producers should then be traversed according to a
// topological sorting.
- SmallVector<Operation *> tiledOps =
+ auto [tiledOps, newContainingOp] =
tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
if (!tiledOps.empty()) {
LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
fusedOps.append(tiledOps);
+ if (newContainingOp) {
+ rewriter.eraseOp(containingOp);
+ containingOp = newContainingOp;
+ }
continue;
}
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index 29cacb474b968..f3f480247f7ff 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -116,7 +116,7 @@ module {
// CHECK: scf.forall {{.*}} -> (tensor<?xf32>) {
%2 = scf.forall (%arg3) in (%d0) shared_outs(%o = %0) -> (tensor<?xf32>) {
%5 = tensor.extract_slice %o[%arg3] [1] [1] : tensor<?xf32> to tensor<f32>
-
+
// CHECK: tensor.extract_slice %{{.*}}[%{{.*}}] [1] [1] : tensor<?xf32> to tensor<1xf32>
// CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<1xf32>) -> tensor<1xf32>
// CHECK: tensor.extract_slice %{{.*}}[0] [1] [1] : tensor<1xf32> to tensor<f32>
@@ -288,3 +288,200 @@ module {
transform.structured.fuse_into_containing_op %2 into %1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
}
}
+
+// -----
+
+#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+
+module {
+ // CHECK-LABEL: func.func @fuse_tileable_multi_output_op_multi_use
+ // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+ // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
+ func.func @fuse_tileable_multi_output_op_multi_use(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
+ -> (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
+ %cst = arith.constant 4.200000e+01 : f32
+ %c0 = arith.constant 0 : index
+
+ // CHECK: %[[G0:.*]]:2 = linalg.generic
+ %0:2 = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]
+ } ins(%in : tensor<?xf32>) outs(%out_1, %out_3 : tensor<?xf32>, tensor<?xf32>) {
+ ^bb0(%a: f32, %b: f32, %c: f32):
+ %d = arith.addf %a, %b : f32
+ %e = arith.addf %d, %c : f32
+ linalg.yield %d, %e : f32, f32
+ } -> (tensor<?xf32>, tensor<?xf32>)
+ %d0 = tensor.dim %out_1, %c0 : tensor<?xf32>
+
+ %1 = affine.apply #map0()[%d0, %idx]
+
+ // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
+ // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
+ %2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
+ // CHECK: %[[I0:.*]] = affine.apply {{.*}}
+ %3 = affine.apply #map1(%i)[%idx]
+ // CHECK: %[[I1:.*]] = affine.min {{.*}}
+ %4 = affine.min #map2(%i)[%d0, %idx]
+ %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+ // CHECK: %[[T1:.*]]:2 = linalg.generic {{.*}}
+ %6 = tensor.extract_slice %0#0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+ %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+ scf.forall.in_parallel {
+ // CHECK: tensor.parallel_insert_slice %[[T1]]#0 into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
+ tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+ }
+ }
+ // CHECK: return %[[R0]]#0, %[[R0]]#1, %[[G0]]#1
+ func.return %2, %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>
+ // CHECK: }
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">
+
+ // linalg.generic is tileable. The op is tiled and fused.
+ transform.structured.fuse_into_containing_op %0 into %1
+ : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
+ }
+}
+
+// -----
+
+#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+
+module {
+ // CHECK-LABEL: func.func @fuse_tileable_mixed_dominating_uses
+ // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+ // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
+ func.func @fuse_tileable_mixed_dominating_uses(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
+ -> (tensor<?xf32>, tensor<?xf32>) {
+ %cst = arith.constant 4.200000e+01 : f32
+ %c0 = arith.constant 0 : index
+
+ // CHECK: %[[G0:.*]] = linalg.generic
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]
+ } ins(%in : tensor<?xf32>) outs(%out_1 : tensor<?xf32>) {
+ ^bb0(%a: f32, %b: f32):
+ %d = arith.addf %a, %b : f32
+ linalg.yield %d : f32
+ } -> tensor<?xf32>
+ // CHECK: %[[D0:.*]] = tensor.dim %[[G0]]
+ %d0 = tensor.dim %0, %c0 : tensor<?xf32>
+
+ %1 = affine.apply #map0()[%d0, %idx]
+
+ // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
+ // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
+ %2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
+ // CHECK: %[[I0:.*]] = affine.apply {{.*}}
+ %3 = affine.apply #map1(%i)[%idx]
+ // CHECK: %[[I1:.*]] = affine.min {{.*}}
+ %4 = affine.min #map2(%i)[%d0, %idx]
+ %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+ // CHECK: %[[T1:.*]] = linalg.generic {{.*}}
+ %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+ %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+ scf.forall.in_parallel {
+ // CHECK: tensor.parallel_insert_slice %[[T1]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
+ tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+ }
+ }
+ // CHECK: return %[[R0]]#0, %[[R0]]#1
+ func.return %2, %0 : tensor<?xf32>, tensor<?xf32>
+ // CHECK: }
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">
+
+ // linalg.generic is tileable. The op is tiled and fused.
+ transform.structured.fuse_into_containing_op %0 into %1
+ : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
+ }
+}
+
+// -----
+
+#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+#map3 = affine_map<(d0, d1) -> (d0, d1)>
+#map4 = affine_map<(d0, d1) -> (d0)>
+
+module {
+ // CHECK-LABEL: func.func @fuse_tileable_reductions
+ // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+ // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?x?xf32>
+ // CHECK-SAME: %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
+ func.func @fuse_tileable_reductions(%idx: index, %in: tensor<?x?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
+ -> (tensor<?xf32>, tensor<?xf32>) {
+ %cst = arith.constant 4.200000e+01 : f32
+ %c0 = arith.constant 0 : index
+
+ %0 = linalg.generic {
+ indexing_maps = [#map3, #map4], iterator_types = ["parallel", "reduction"]
+ } ins(%in : tensor<?x?xf32>) outs(%out_1 : tensor<?xf32>) {
+ ^bb0(%a: f32, %b: f32):
+ %d = arith.maxf %a, %b : f32
+ linalg.yield %d : f32
+ } -> tensor<?xf32>
+ %d0 = tensor.dim %out_1, %c0 : tensor<?xf32>
+
+ %1 = affine.apply #map0()[%d0, %idx]
+
+ // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
+ // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
+ %2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
+ // CHECK: %[[I0:.*]] = affine.apply {{.*}}
+ %3 = affine.apply #map1(%i)[%idx]
+ // CHECK: %[[I1:.*]] = affine.min {{.*}}
+ %4 = affine.min #map2(%i)[%d0, %idx]
+ %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+ // CHECK: %[[T1:.*]] = linalg.generic {{.*}}
+ %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+ %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+ scf.forall.in_parallel {
+ // CHECK: tensor.parallel_insert_slice %[[T1]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
+ tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+ }
+ }
+ // CHECK: return %[[R0]]#0, %[[R0]]#1
+ func.return %2, %0 : tensor<?xf32>, tensor<?xf32>
+ // CHECK: }
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">
+
+ // linalg.generic is tileable. The op is tiled and fused.
+ transform.structured.fuse_into_containing_op %0 into %1
+ : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
+ }
+}
More information about the Mlir-commits
mailing list