[Mlir-commits] [mlir] bc882ed - [mlir][linalg][transform] Add fuse_into_containing op

Matthias Springer llvmlistbot at llvm.org
Fri Jul 22 04:58:45 PDT 2022


Author: Matthias Springer
Date: 2022-07-22T13:55:04+02:00
New Revision: bc882ed21fc7ae4ad934bf432c2431fa27cad556

URL: https://github.com/llvm/llvm-project/commit/bc882ed21fc7ae4ad934bf432c2431fa27cad556
DIFF: https://github.com/llvm/llvm-project/commit/bc882ed21fc7ae4ad934bf432c2431fa27cad556.diff

LOG: [mlir][linalg][transform] Add fuse_into_containing op

This op fuses a given payload op into a given container op. Inside the container, all uses of the producer are replaced (fused) with the newly inserted op. If the producer is tileable and accessed via a tensor.extract_slice, the new op computes only the requested slice ("tile and fuse"). Otherwise, the entire tensor value is computed inside the container ("clone and fuse").

Differential Revision: https://reviews.llvm.org/D130244

Added: 
    mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index b8bcf136ee383..f97061d516d58 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -65,6 +65,54 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
   let hasVerifier = 1;
 }
 
+def FuseIntoContainingOp :
+    Op<Transform_Dialect, "structured.fuse_into_containing_op",
+      [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Fuse a producer into a containing operation.";
+
+  let description = [{
+    Fuses the `producer_op` into the `containing_op`. Only producers with a
+    single result are supported at the moment. Returns a handle to the fused
+    ops.
+
+    The producer is typically a tileable op (i.e., implements
+    TilingInterface) whose result is accessed inside the containing op via a
+    `tensor.extract_slice`. In that case, this transform computes only the
+    accessed producer slice inside of the containing op ("tile and fuse").
+    All remaining uses of the producer inside the containing op (e.g., uses
+    that are not `tensor.extract_slice` ops, or any use of a non-tileable
+    producer) are handled by cloning the entire producer inside the
+    containing op ("clone and fuse").
+
+    The containing op handle must be associated with exactly one payload op. The
+    producer op handle may be associated with multiple payload ops. This
+    transform fuses producers one-by-one, always picking an unspecified producer
+    among the remaining producers that has at least one use inside the
+    containing op.
+
+    Note: If a producer has multiple uses inside the containing op, it is
+    currently tiled and/or cloned multiple times into the containing op.
+    TODO: Reuse already fused OpResults instead of tiling/cloning a second time
+    when possible. Fuse producers according to a topological sorting to achieve
+    the largest amount of reuse.
+
+    #### Return modes
+
+    If at least one producer could not be fused, this operation fails silently.
+    This is the case when a producer does not have exactly one result, when a
+    producer can neither be tiled nor cloned into the containing op, or when no
+    producer op could be found among the remaining producers that has at least
+    one use within the containing op. I.e., "producers" that are not consumed
+    within the containing op are rejected by this operation. If the containing
+    op handle is not associated with exactly one payload op, this operation
+    emits an error. This operation reads and frees the producer handle. It
+    reads the containing op handle.
+  }];
+
+  let arguments = (ins Arg<PDL_Operation, "",
+                           [TransformMappingRead,
+                            TransformMappingFree]>:$producer_op,
+                       Arg<PDL_Operation, "",
+                           [TransformMappingRead]>:$containing_op);
+  let results = (outs Res<PDL_Operation, "",
+                          [TransformMappingAlloc,
+                           TransformMappingWrite]>:$fused_op);
+  let assemblyFormat = "$producer_op `into` $containing_op attr-dict";
+}
+
 def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformOpInterface, TransformEachOpTrait]> {

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 070b1fc4eb821..a74f3d4e3d3c6 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -213,6 +213,160 @@ LogicalResult transform::FuseOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// FuseIntoContainingOp
+//===----------------------------------------------------------------------===//
+
+/// Tiles `producerOp` into every `tensor.extract_slice` user of it that is
+/// nested inside `containingOp` ("tile and fuse").
+///
+/// For each such slice, a tile of result #0 of the producer covering the
+/// slice's offsets/sizes is generated right before the slice op. Once all
+/// tiles were generated, each slice op is replaced with its tile. Returns the
+/// ops that define the tiles.
+///
+/// Fails if the producer does not implement TilingInterface, if no
+/// `tensor.extract_slice` user lies inside `containingOp`, or if generating a
+/// tile fails.
+/// NOTE(review): on a failure after some tiles were already generated, those
+/// tiles are left (unused) in the IR.
+static FailureOr<SmallVector<Operation *>> tileAndFuse(Operation *producerOp,
+                                                       Operation *containingOp,
+                                                       RewriterBase &rewriter) {
+  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
+  if (!tileableProducer)
+    return failure();
+
+  // Search the producer slices accessed within the containing operation.
+  // TODO: Generalize to more extract/insert/parallel_insert triples. Maybe
+  // evolve into an interface.
+  SmallVector<tensor::ExtractSliceOp> sliceOps;
+  for (Operation *user : tileableProducer->getUsers()) {
+    auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
+    if (sliceOp && containingOp->isProperAncestor(sliceOp))
+      sliceOps.push_back(sliceOp);
+  }
+
+  // Check for a non-empty list of fusion opportunities.
+  if (sliceOps.empty())
+    return failure();
+
+  SmallVector<Value> destinationOperands =
+      tileableProducer.getDestinationOperands(rewriter);
+
+  // Generate one tile per slice. The slice ops are only replaced after all
+  // tiles were generated successfully.
+  SmallVector<Value> tiledValues;
+  SmallVector<Operation *> fusedOps;
+  tiledValues.reserve(sliceOps.size());
+  fusedOps.reserve(sliceOps.size());
+  for (tensor::ExtractSliceOp sliceOp : sliceOps) {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(sliceOp);
+
+    // Tile the producer.
+    FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
+        rewriter, /*resultNumber=*/0, destinationOperands,
+        sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
+        /*tileDestOperands=*/true);
+    if (failed(tiledProducer))
+      return failure();
+    // A tile that is not produced by an op cannot be reported as a fused op.
+    Operation *tiledOp = tiledProducer->getDefiningOp();
+    if (!tiledOp)
+      return failure();
+    tiledValues.push_back(*tiledProducer);
+    fusedOps.push_back(tiledOp);
+  }
+
+  // Replace each extract_slice with the exact tile value computed for it.
+  // Using result #0 of the tile's defining op would pick the wrong value if
+  // the tile is not the first result of that op.
+  for (const auto &en : enumerate(sliceOps))
+    rewriter.replaceOp(en.value(), tiledValues[en.index()]);
+  return fusedOps;
+}
+
+/// Clones `producerOp` right before every op nested inside `containingOp`
+/// that uses one of the producer's results ("clone and fuse"). Each such use
+/// is rewired to the matching result of its own clone, so one clone is
+/// created per use. Returns the clones, or failure if none of the producer's
+/// uses lies inside `containingOp`.
+static FailureOr<SmallVector<Operation *>>
+cloneAndFuse(Operation *producerOp, Operation *containingOp,
+             RewriterBase &rewriter) {
+  // Snapshot the interesting uses up front; rewiring them below mutates the
+  // producer's use lists.
+  SmallVector<OpOperand *> usesInside;
+  for (OpResult producerResult : producerOp->getOpResults()) {
+    for (OpOperand &operand : producerResult.getUses()) {
+      if (!containingOp->isProperAncestor(operand.getOwner()))
+        continue;
+      usesInside.push_back(&operand);
+    }
+  }
+
+  // Nothing to fuse if the producer is not consumed inside the container.
+  if (usesInside.empty())
+    return failure();
+
+  SmallVector<Operation *> clones;
+  clones.reserve(usesInside.size());
+  for (OpOperand *operand : usesInside) {
+    Operation *user = operand->getOwner();
+    unsigned resultIdx = operand->get().cast<OpResult>().getResultNumber();
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(user);
+    Operation *clone = rewriter.clone(*producerOp);
+    Value replacement = clone->getOpResult(resultIdx);
+    rewriter.updateRootInPlace(user, [&] { operand->set(replacement); });
+    clones.push_back(clone);
+  }
+  return clones;
+}
+
+/// Fuses every payload op associated with the `producer_op` handle into the
+/// single payload op associated with the `containing_op` handle. The
+/// `fused_op` result is mapped to all ops created by tiling and/or cloning.
+///
+/// Returns a silenceable failure if a producer does not have exactly one
+/// result, if no remaining producer has a use inside the containing op, or if
+/// a producer can neither be tiled nor cloned into the containing op. Emits an
+/// op error if the containing op handle is not associated with exactly one
+/// payload op.
+/// NOTE(review): on any early failure the `fused_op` result is left unset.
+DiagnosedSilenceableFailure
+transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
+                                       transform::TransformState &state) {
+  SmallVector<Operation *> fusedOps;
+  ArrayRef<Operation *> producerOps = state.getPayloadOps(getProducerOp());
+  // Only single-result producers are supported at the moment.
+  for (Operation *producerOp : producerOps) {
+    if (producerOp->getNumResults() != 1) {
+      Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note);
+      diag << "op with != 1 results not supported";
+      return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+    }
+  }
+  ArrayRef<Operation *> containingOps = state.getPayloadOps(getContainingOp());
+  if (containingOps.size() != 1)
+    return DiagnosedSilenceableFailure(
+        this->emitOpError("requires exactly one containing_op handle"));
+  Operation *containingOp = containingOps.front();
+
+  // Helper function to find the next producer that should be fused. Take any
+  // producer that has a use inside the containing op. The chosen producer is
+  // removed from `remainingProducers`.
+  SmallVector<Operation *> remainingProducers(producerOps.begin(),
+                                              producerOps.end());
+  auto getNextProducer = [&]() -> FailureOr<Operation *> {
+    for (const auto &it : enumerate(remainingProducers)) {
+      Operation *producerOp = it.value();
+      bool hasUseInContainingOp =
+          any_of(producerOp->getUsers(), [&](Operation *op) {
+            return containingOp->isProperAncestor(op);
+          });
+      // TODO: When resolving the TODO below (no duplicate ops), take an op that
+      // has no use among the remaining producers. This is a topological
+      // sorting.
+      if (hasUseInContainingOp) {
+        // Returning right after the erase keeps the enumerate loop valid.
+        remainingProducers.erase(remainingProducers.begin() + it.index());
+        return producerOp;
+      }
+    }
+    return failure();
+  };
+
+  IRRewriter rewriter(getContext());
+  while (!remainingProducers.empty()) {
+    auto nextProducer = getNextProducer();
+    if (failed(nextProducer)) {
+      Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Note);
+      diag << "could not fuse ops into container";
+      return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+    }
+
+    Operation *producerOp = *nextProducer;
+    // TODO: If there are multiple uses of the producer in the containing op, we
+    // currently tile/clone the op multiple times (once per use). In some cases,
+    // we can tile/clone once and reuse the value for each use. Furthermore,
+    // producers should then be traversed according to a topological sorting.
+
+    // First tile the producer into its extract_slice users inside the
+    // containing op, then clone it for every use that is still left inside the
+    // containing op. The producer only counts as unfusable if both fail.
+    auto tiled = tileAndFuse(producerOp, containingOp, rewriter);
+    if (succeeded(tiled))
+      fusedOps.append(*tiled);
+
+    auto cloned = cloneAndFuse(producerOp, containingOp, rewriter);
+    if (succeeded(cloned))
+      fusedOps.append(*cloned);
+
+    if (failed(tiled) && failed(cloned)) {
+      Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note);
+      diag << "could not fuse into containing op";
+      return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+    }
+  }
+
+  results.set(getFusedOp().cast<OpResult>(), fusedOps);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // GeneralizeOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
new file mode 100644
index 0000000000000..8b95cec4fcb8c
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -0,0 +1,96 @@
+// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s
+
+#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+
+module {
+  // Tile-and-fuse case: %0 (a linalg.fill writing into %arg1) is consumed
+  // inside the scf.foreach_thread only through a tensor.extract_slice. The
+  // expectations below require that slice to be replaced by a linalg.fill
+  // whose output is a slice of %arg1, feeding linalg.elemwise_unary.
+  // CHECK-LABEL: func.func @fuse_tileable_op
+  //  CHECK-SAME:   %[[CHUNK_SIZE:[0-9a-z]+]]: index
+  //  CHECK-SAME:   %[[IN:[0-9a-z]+]]: tensor<?xf32>
+  //  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: tensor<?xf32>
+  func.func @fuse_tileable_op(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
+    %cst = arith.constant 4.200000e+01 : f32
+    %c0 = arith.constant 0 : index
+    %0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
+    %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
+    %1 = affine.apply #map0()[%d0, %arg0]
+
+    // CHECK: scf.foreach_thread {{.*}} {
+    %2 = scf.foreach_thread (%arg3) in (%1)  -> (tensor<?xf32>) {
+      %3 = affine.apply #map1(%arg3)[%arg0]
+      %4 = affine.min #map2(%arg3)[%d0, %arg0]
+      %5 = tensor.extract_slice %arg2[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+      // CHECK: %[[T0:.*]] = tensor.extract_slice %[[IN]][%{{.*}}] [%{{.*}}] [{{.*}}]
+      // CHECK: %[[T1:.*]] = linalg.fill {{.*}} outs(%[[T0]]
+      %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+      // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[T1]]
+      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+      scf.foreach_thread.perform_concurrently {
+        tensor.parallel_insert_slice %7 into %arg2[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+      }
+    }
+    // CHECK: }
+    func.return %2 : tensor<?xf32>
+  }
+
+  // Transform script: match the producer (linalg.fill) and the containing op
+  // (scf.foreach_thread), then fuse the former into the latter.
+  transform.with_pdl_patterns {
+  ^bb0(%arg0: !pdl.operation):
+    transform.sequence %arg0 {
+    ^bb1(%arg1: !pdl.operation):
+      %0 = transform.structured.match ops{["linalg.fill"]} in %arg1
+      %1 = transform.structured.match ops{["scf.foreach_thread"]} in %arg1
+
+      // linalg.fill is tileable. The op is tiled and fused.
+      transform.structured.fuse_into_containing_op %0 into %1
+    }
+  }
+}
+
+// -----
+
+#map0 = affine_map<()[s0] -> (64 ceildiv s0)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0] -> (-(d0 * s0) + 64, s0)>
+
+module {
+  // Clone-and-fuse case: %0 (linalg.init_tensor) is used directly, not through
+  // a tensor.extract_slice, by linalg.elemwise_unary inside the
+  // scf.foreach_thread. The expectations below require a copy of
+  // linalg.init_tensor inside the loop body that feeds linalg.elemwise_unary.
+  // CHECK-LABEL: func.func @fuse_untileable_op
+  //  CHECK-SAME:   %[[CHUNK_SIZE:[0-9a-z]+]]: index
+  //  CHECK-SAME:   %[[IN:[0-9a-z]+]]: tensor<64xf32>
+  //  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: tensor<64xf32>
+  func.func @fuse_untileable_op(%arg0: index, %arg1: tensor<64xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+    %0 = linalg.init_tensor [%arg0] : tensor<?xf32>
+    %1 = affine.apply #map0()[%arg0]
+
+    // CHECK: scf.foreach_thread {{.*}} {
+    %2 = scf.foreach_thread (%arg3) in (%1)  -> (tensor<64xf32>) {
+      // CHECK: %[[INIT_TENSOR:.*]] = linalg.init_tensor
+      %3 = affine.apply #map1(%arg3)[%arg0]
+      %4 = affine.min #map2(%arg3)[%arg0]
+      %5 = tensor.extract_slice %arg2[%3] [%4] [1] : tensor<64xf32> to tensor<?xf32>
+
+      // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[INIT_TENSOR]]
+      %7 = linalg.elemwise_unary ins(%0 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+      scf.foreach_thread.perform_concurrently {
+        tensor.parallel_insert_slice %7 into %arg2[%3] [%4] [1] : tensor<?xf32> into tensor<64xf32>
+      }
+    }
+    // CHECK: }
+
+    func.return %2 : tensor<64xf32>
+  }
+
+  // Transform script: match the producer (linalg.init_tensor) and the
+  // containing op (scf.foreach_thread), then fuse the former into the latter.
+  transform.with_pdl_patterns {
+  ^bb0(%arg0: !pdl.operation):
+    transform.sequence %arg0 {
+    ^bb1(%arg1: !pdl.operation):
+      %0 = transform.structured.match ops{["linalg.init_tensor"]} in %arg1
+      %1 = transform.structured.match ops{["scf.foreach_thread"]} in %arg1
+
+      // linalg.init_tensor is not tileable. The op is cloned and fused.
+      transform.structured.fuse_into_containing_op %0 into %1
+    }
+  }
+}


        


More information about the Mlir-commits mailing list