[Mlir-commits] [mlir] bc882ed - [mlir][linalg][transform] Add fuse_into_containing op
Matthias Springer
llvmlistbot at llvm.org
Fri Jul 22 04:58:45 PDT 2022
Author: Matthias Springer
Date: 2022-07-22T13:55:04+02:00
New Revision: bc882ed21fc7ae4ad934bf432c2431fa27cad556
URL: https://github.com/llvm/llvm-project/commit/bc882ed21fc7ae4ad934bf432c2431fa27cad556
DIFF: https://github.com/llvm/llvm-project/commit/bc882ed21fc7ae4ad934bf432c2431fa27cad556.diff
LOG: [mlir][linalg][transform] Add fuse_into_containing op
This op fuses a given payload op into a given container op. Inside the container, all uses of the producer are replaced (fused) with the newly inserted op. If the producer is tileable and accessed via a tensor.extract_slice, the new op computes only the requested slice ("tile and fuse"). Otherwise, the entire tensor value is computed inside the container ("clone and fuse").
Differential Revision: https://reviews.llvm.org/D130244
Added:
mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index b8bcf136ee383..f97061d516d58 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -65,6 +65,54 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
let hasVerifier = 1;
}
+def FuseIntoContainingOp :
+    Op<Transform_Dialect, "structured.fuse_into_containing_op",
+      [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Fuse a producer into a containing operation.";
+
+  let description = [{
+    Fuses the `producer_op` into the `containing_op`. Only producers with a
+    single result are supported at the moment. Returns a handle to the fused
+    ops.
+
+    The producer is typically a tileable op (i.e., it implements
+    TilingInterface) whose result is read inside the containing op through
+    `tensor.extract_slice` ops. In that case, this transform computes only the
+    accessed producer slice inside of the containing op ("tile and fuse").
+    Uses inside the containing op that are not handled by tiling are served by
+    cloning the entire producer inside the containing op ("clone and fuse").
+
+    The containing op handle must be associated with exactly one payload op.
+    The producer op handle may be associated with multiple payload ops. This
+    transform fuses producers one at a time, each time picking an arbitrary
+    remaining producer that has at least one use inside the containing op.
+
+    Note: If a producer has multiple uses inside the containing op, it is
+    currently tiled and/or cloned multiple times into the containing op.
+    TODO: Reuse already fused OpResults instead of tiling/cloning a second time
+    when possible. Fuse producers according to a topological sorting to achieve
+    the largest amount of reuse.
+
+    #### Return modes
+
+    If at least one producer could not be fused, this operation produces a
+    silenceable failure. This is the case when a producer does not have
+    exactly one result, when neither tiling nor cloning succeeds for a
+    producer, or when no producer with at least one use within the containing
+    op can be found among the remaining producers. I.e., "producers" that are
+    not consumed within the containing op are rejected by this operation.
+    If the containing op handle is not associated with exactly one payload op,
+    a definite error is emitted.
+
+    This operation reads and frees the producer handle. It reads the
+    containing op handle.
+  }];
+
+  let arguments = (ins Arg<PDL_Operation, "",
+                           [TransformMappingRead,
+                            TransformMappingFree]>:$producer_op,
+                       Arg<PDL_Operation, "",
+                           [TransformMappingRead]>:$containing_op);
+  let results = (outs Res<PDL_Operation, "",
+                          [TransformMappingAlloc,
+                           TransformMappingWrite]>:$fused_op);
+  let assemblyFormat = "$producer_op `into` $containing_op attr-dict";
+}
+
def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait]> {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 070b1fc4eb821..a74f3d4e3d3c6 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -213,6 +213,160 @@ LogicalResult transform::FuseOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// FuseIntoContainingOp
+//===----------------------------------------------------------------------===//
+
+/// Tiles `producerOp` into `containingOp` ("tile and fuse").
+///
+/// For every tensor.extract_slice nested inside `containingOp` that reads a
+/// result of `producerOp`, a tile of the producer covering exactly that slice
+/// is generated right before the slice op, and the slice op is replaced with
+/// it. Slices for which tiling fails are left untouched: they still read the
+/// original producer, so a subsequent clone-and-fuse step can handle them.
+///
+/// Returns the ops defining the tiled values, or failure if the producer does
+/// not implement TilingInterface, has no such slice inside `containingOp`, or
+/// none of its slices could be tiled.
+static FailureOr<SmallVector<Operation *>> tileAndFuse(Operation *producerOp,
+                                                       Operation *containingOp,
+                                                       RewriterBase &rewriter) {
+  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
+  if (!tileableProducer)
+    return failure();
+
+  // Search the producer slices accessed within the containing operation and
+  // remember which producer result each of them reads.
+  // TODO: Generalize to more extract/insert/parallel_insert triples. Maybe
+  // evolve into an interface.
+  SmallVector<std::pair<tensor::ExtractSliceOp, unsigned>> sliceOps;
+  for (OpResult result : producerOp->getOpResults()) {
+    for (OpOperand &use : result.getUses()) {
+      auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(use.getOwner());
+      if (!sliceOp || !containingOp->isProperAncestor(sliceOp))
+        continue;
+      sliceOps.emplace_back(sliceOp, result.getResultNumber());
+    }
+  }
+
+  // Check for a non-empty list of fusion opportunities.
+  if (sliceOps.empty())
+    return failure();
+
+  // The destination operands are shared by all tiles, so they must dominate
+  // every slice op. Compute them right before the producer: the interface
+  // implementation may create ops, and the incoming rewriter is not
+  // guaranteed to have an insertion point set.
+  SmallVector<Value> destinationOperands;
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(producerOp);
+    destinationOperands = tileableProducer.getDestinationOperands(rewriter);
+  }
+
+  // Tile the producer once per slice and replace each slice with its tile.
+  SmallVector<Operation *> fusedOps;
+  for (const auto &sliceAndResult : sliceOps) {
+    tensor::ExtractSliceOp sliceOp = sliceAndResult.first;
+    unsigned resultNumber = sliceAndResult.second;
+
+    // Materialize the tile right before the slice it replaces. The guard is
+    // scoped so the insertion point never refers to the erased slice op.
+    FailureOr<Value> tiledProducer = failure();
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPoint(sliceOp);
+      tiledProducer = tileableProducer.generateResultTileValue(
+          rewriter, resultNumber, destinationOperands,
+          sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
+          /*tileDestOperands=*/true);
+    }
+    // Keep the slice on failure; it remains a regular use of the producer.
+    if (failed(tiledProducer))
+      continue;
+
+    // Replace the slice with the tiled value itself rather than with result
+    // #0 of its defining op, which need not be the same value.
+    fusedOps.push_back(tiledProducer->getDefiningOp());
+    rewriter.replaceOp(sliceOp, *tiledProducer);
+  }
+
+  // Tiling may have failed for every slice.
+  if (fusedOps.empty())
+    return failure();
+  return fusedOps;
+}
+
+/// Clones `producerOp` into `containingOp` ("clone and fuse").
+///
+/// Every operand nested (strictly) inside `containingOp` that reads a result
+/// of `producerOp` gets its own clone of the producer, inserted right before
+/// the operand's owner, and is redirected to the matching cloned result.
+/// Returns the clones, or failure if there is no such use.
+static FailureOr<SmallVector<Operation *>>
+cloneAndFuse(Operation *producerOp, Operation *containingOp,
+             RewriterBase &rewriter) {
+  // Collect every use of a producer result that lives inside the container.
+  SmallVector<OpOperand *> usesToReplace;
+  for (OpResult producerResult : producerOp->getOpResults()) {
+    for (OpOperand &operand : producerResult.getUses()) {
+      Operation *user = operand.getOwner();
+      if (!containingOp->isProperAncestor(user))
+        continue;
+      usesToReplace.push_back(&operand);
+    }
+  }
+
+  // Nothing to fuse into: let the caller decide how to report it.
+  if (usesToReplace.empty())
+    return failure();
+
+  // One private clone per use, placed immediately before the using op.
+  SmallVector<Operation *> clonedOps;
+  for (OpOperand *operand : usesToReplace) {
+    Operation *user = operand->getOwner();
+    unsigned resultIdx = operand->get().cast<OpResult>().getResultNumber();
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(user);
+    Operation *clone = rewriter.clone(*producerOp);
+    Value replacement = clone->getOpResult(resultIdx);
+    rewriter.updateRootInPlace(user, [&] { operand->set(replacement); });
+    clonedOps.push_back(clone);
+  }
+  return clonedOps;
+}
+
+/// Fuses each payload op of the `producer_op` handle into the single payload
+/// op of the `containing_op` handle. For each producer, tileAndFuse is tried
+/// first and cloneAndFuse is then applied to whatever uses remain; all
+/// resulting ops are mapped to the `fused_op` result handle. Producers with
+/// != 1 results, or producers that cannot be fused, yield a silenceable
+/// failure; a containing handle not mapped to exactly one op yields a
+/// definite error. The original producer ops are not erased here.
+DiagnosedSilenceableFailure
+transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
+ transform::TransformState &state) {
+ SmallVector<Operation *> fusedOps;
+ ArrayRef<Operation *> producerOps = state.getPayloadOps(getProducerOp());
+ // Reject multi-result producers up front, before any IR is modified.
+ for (Operation *producerOp : producerOps) {
+ if (producerOp->getNumResults() != 1) {
+ Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note);
+ diag << "op with != 1 results not supported";
+ return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+ }
+ }
+ // A containing handle mapped to anything but exactly one op is a definite
+ // (non-silenceable) error.
+ ArrayRef<Operation *> containingOps = state.getPayloadOps(getContainingOp());
+ if (containingOps.size() != 1)
+ return DiagnosedSilenceableFailure(
+ this->emitOpError("requires exactly one containing_op handle"));
+ Operation *containingOp = containingOps.front();
+
+ // Helper function to find the next producer that should be fused. Take any
+ // producer that has a use inside the containing op.
+ SmallVector<Operation *> remainingProducers(producerOps.begin(),
+ producerOps.end());
+ auto getNextProducer = [&]() -> FailureOr<Operation *> {
+ for (const auto &it : enumerate(remainingProducers)) {
+ Operation *producerOp = it.value();
+ bool hasUseInContainingOp =
+ any_of(producerOp->getUsers(), [&](Operation *op) {
+ return containingOp->isProperAncestor(op);
+ });
+ // TODO: When resolving the TODO below (no duplicate ops), take an op that
+ // has no use among the remaining producers. This is a topological
+ // sorting.
+ if (hasUseInContainingOp) {
+ // Erasing invalidates the enumerate iterator; this is safe only because
+ // the lambda returns immediately afterwards.
+ remainingProducers.erase(remainingProducers.begin() + it.index());
+ return producerOp;
+ }
+ }
+ return failure();
+ };
+
+ // NOTE(review): the rewriter is created without an insertion point; the
+ // fusion helpers are expected to position it before creating ops.
+ IRRewriter rewriter(getContext());
+ while (!remainingProducers.empty()) {
+ auto nextProducer = getNextProducer();
+ if (failed(nextProducer)) {
+ Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Note);
+ diag << "could not fuse ops into container";
+ return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+ }
+
+ Operation *producerOp = *nextProducer;
+ // TODO: If there are multiple uses of the producer in the containing op, we
+ // currently tile/clone the op multiple times (once per use). In some cases,
+ // we can tile/clone once and reuse the value for each use. Furthermore,
+ // producers should then be traversed according to a topological sorting.
+ // Tiling handles uses reached through tensor.extract_slice; any uses that
+ // still point at the producer afterwards are served by cloning.
+ auto tiled = tileAndFuse(producerOp, containingOp, rewriter);
+ if (succeeded(tiled))
+ fusedOps.append(*tiled);
+
+ auto cloned = cloneAndFuse(producerOp, containingOp, rewriter);
+ if (succeeded(cloned))
+ fusedOps.append(*cloned);
+
+ // Fail only when neither strategy fused anything for this producer.
+ // NOTE(review): earlier producers may already have been fused at this
+ // point; those IR changes are not rolled back.
+ if (failed(tiled) && failed(cloned)) {
+ Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note);
+ diag << "could not fuse into containing op";
+ return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+ }
+ }
+
+ // Map every tiled and cloned op to the `fused_op` result handle.
+ results.set(getFusedOp().cast<OpResult>(), fusedOps);
+ return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// GeneralizeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
new file mode 100644
index 0000000000000..8b95cec4fcb8c
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -0,0 +1,96 @@
+// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s
+
+// Tile and fuse: the linalg.fill result is only read inside the
+// scf.foreach_thread through a tensor.extract_slice, so the slice is expected
+// to be replaced by a fill computed on a matching slice of the fill's init
+// operand (the IN argument).
+#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+
+module {
+ // CHECK-LABEL: func.func @fuse_tileable_op
+ // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+ // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<?xf32>
+ func.func @fuse_tileable_op(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
+ %cst = arith.constant 4.200000e+01 : f32
+ %c0 = arith.constant 0 : index
+ // Producer to be fused; it lives outside the containing op.
+ %0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
+ %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
+ %1 = affine.apply #map0()[%d0, %arg0]
+
+ // CHECK: scf.foreach_thread {{.*}} {
+ %2 = scf.foreach_thread (%arg3) in (%1) -> (tensor<?xf32>) {
+ %3 = affine.apply #map1(%arg3)[%arg0]
+ %4 = affine.min #map2(%arg3)[%d0, %arg0]
+ %5 = tensor.extract_slice %arg2[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+ // CHECK: %[[T0:.*]] = tensor.extract_slice %[[IN]][%{{.*}}] [%{{.*}}] [{{.*}}]
+ // CHECK: %[[T1:.*]] = linalg.fill {{.*}} outs(%[[T0]]
+ // This slice of the producer is the fusion opportunity.
+ %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+ // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[T1]]
+ %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+ scf.foreach_thread.perform_concurrently {
+ tensor.parallel_insert_slice %7 into %arg2[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+ }
+ }
+ // CHECK: }
+ func.return %2 : tensor<?xf32>
+ }
+
+ transform.with_pdl_patterns {
+ ^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.fill"]} in %arg1
+ %1 = transform.structured.match ops{["scf.foreach_thread"]} in %arg1
+
+ // linalg.fill is tileable. The op is tiled and fused.
+ transform.structured.fuse_into_containing_op %0 into %1
+ }
+ }
+}
+
+// -----
+
+// Clone and fuse: linalg.init_tensor is not tileable and is used directly (not
+// through a tensor.extract_slice) inside the scf.foreach_thread, so the whole
+// op is expected to be cloned into the loop body.
+// NOTE(review): %arg1 (IN) is captured by the function signature check but is
+// not otherwise used in this test.
+#map0 = affine_map<()[s0] -> (64 ceildiv s0)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0] -> (-(d0 * s0) + 64, s0)>
+
+module {
+ // CHECK-LABEL: func.func @fuse_untileable_op
+ // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+ // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<64xf32>
+ // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<64xf32>
+ func.func @fuse_untileable_op(%arg0: index, %arg1: tensor<64xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+ // Producer to be fused; it lives outside the containing op.
+ %0 = linalg.init_tensor [%arg0] : tensor<?xf32>
+ %1 = affine.apply #map0()[%arg0]
+
+ // CHECK: scf.foreach_thread {{.*}} {
+ %2 = scf.foreach_thread (%arg3) in (%1) -> (tensor<64xf32>) {
+ // CHECK: %[[INIT_TENSOR:.*]] = linalg.init_tensor
+ %3 = affine.apply #map1(%arg3)[%arg0]
+ %4 = affine.min #map2(%arg3)[%arg0]
+ %5 = tensor.extract_slice %arg2[%3] [%4] [1] : tensor<64xf32> to tensor<?xf32>
+
+ // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[INIT_TENSOR]]
+ // Direct (non-slice) use of the producer inside the containing op.
+ %7 = linalg.elemwise_unary ins(%0 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+ scf.foreach_thread.perform_concurrently {
+ tensor.parallel_insert_slice %7 into %arg2[%3] [%4] [1] : tensor<?xf32> into tensor<64xf32>
+ }
+ }
+ // CHECK: }
+
+ func.return %2 : tensor<64xf32>
+ }
+
+ transform.with_pdl_patterns {
+ ^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.init_tensor"]} in %arg1
+ %1 = transform.structured.match ops{["scf.foreach_thread"]} in %arg1
+
+ // linalg.init_tensor is not tileable. The op is cloned and fused.
+ transform.structured.fuse_into_containing_op %0 into %1
+ }
+ }
+}
More information about the Mlir-commits
mailing list