[Mlir-commits] [mlir] 6db928b - [mlir][linalg] Fusion on tensors.

Tobias Gysi llvmlistbot at llvm.org
Mon Sep 20 07:46:06 PDT 2021


Author: Tobias Gysi
Date: 2021-09-20T14:45:34Z
New Revision: 6db928b8f31b17caf205eee9c95bb817e51a3f2c

URL: https://github.com/llvm/llvm-project/commit/6db928b8f31b17caf205eee9c95bb817e51a3f2c
DIFF: https://github.com/llvm/llvm-project/commit/6db928b8f31b17caf205eee9c95bb817e51a3f2c.diff

LOG: [mlir][linalg] Fusion on tensors.

Add a new version of fusion on tensors with the following features and limitations:
- supports input and output operand fusion
- fuses a producer result passed in via tile loop iteration arguments (updating the tile loop iteration arguments)
- supports only linalg operations on tensors
- supports only scf::for
- cannot add an output to the tile loop nest

The LinalgTileAndFuseTensorOps pass tiles the root operation and fuses its producers.

Reviewed By: nicolasvasilache, mravishankar

Differential Revision: https://reviews.llvm.org/D109766

Added: 
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.h
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 12b55a153d16..063bf71d0523 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -77,6 +77,9 @@ std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass();
 /// work on primitive types, if possible.
 std::unique_ptr<Pass> createLinalgDetensorizePass();
 
+/// Create a pass to tile a LinalgOp and fuse its producers. The pass selects
+/// the LinalgOp with the most LinalgOp operations in its backward slice as the
+/// root operation to tile.
+std::unique_ptr<OperationPass<FuncOp>> createLinalgTileAndFuseTensorOpsPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index ecde91ff5120..acd18ff7977f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -243,4 +243,18 @@ def LinalgDetensorize : FunctionPass<"linalg-detensorize"> {
   }];
 }
 
+// Tiles the LinalgOp with the most LinalgOp operations in its backward slice
+// and fuses its producers into the resulting tile loop nest.
+//  - tile-sizes: needs at least one entry per root loop; extra entries are
+//    ignored.
+//  - tile-interchange: empty (identity) or as many entries as tile-sizes; the
+//    entries used for the root loops have to permute the root loops.
+def LinalgTileAndFuseTensorOps
+    : FunctionPass<"linalg-tile-and-fuse-tensor-ops"> {
+  let summary = "Tile a LinalgOp and fuse its producers.";
+  let constructor = "mlir::createLinalgTileAndFuseTensorOpsPass()";
+  let options = [
+    ListOption<"tileSizes", "tile-sizes", "int64_t", "Tile sizes",
+           "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
+    ListOption<"tileInterchange", "tile-interchange", "int64_t",
+           "Tile loop interchange",
+           "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
+  ];
+  let dependentDialects = ["linalg::LinalgDialect", "scf::SCFDialect"];
+}
+
 #endif // MLIR_DIALECT_LINALG_PASSES

diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 8776b7404542..a8d4edf6c072 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -172,6 +172,64 @@ Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
                                           OpResult producerOpResult,
                                           OpOperand &consumerOpOperand);
 
+//===----------------------------------------------------------------------===//
+// Fusion on tensor utilities
+//===----------------------------------------------------------------------===//
+
+/// A class to manage the tile loop nest specific information.
+class TileLoopNest {
+public:
+  TileLoopNest(LinalgOp rootOp) : rootOp(rootOp) {}
+
+  /// Tile the root operation using the given `tileSizes` and `tileInterchange`.
+  /// Succeeds without tiling if all tile sizes are zero.
+  LogicalResult tileRootOp(OpBuilder &b, ArrayRef<int64_t> tileSizes,
+                           ArrayRef<int64_t> tileInterchange);
+
+  /// Fuse the producer of `rootOpOperand` into the tile loop nest. Returns the
+  /// fused producer or fails if fusion is not possible.
+  // TODO: add replace uses callback to support passes and patterns.
+  FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpOperand *rootOpOperand);
+
+  /// Returns the tiled root operation.
+  LinalgOp getRootOp() { return rootOp; }
+
+private:
+  /// Returns true if the tile loop nest has no tile loops.
+  bool isEmpty();
+
+  /// Returns true if the tile loop nest invariants are satisfied:
+  /// - The number of tile loop operations and dimensions match.
+  /// - The innermost tile loop is the parent of `rootOp`.
+  /// - The tile loops are directly nested.
+  // TODO: relax to support additional control flow, e.g., IfOp.
+  bool isValid();
+
+  /// Searches the block arguments tied to a block argument `bbArg` of the
+  /// innermost tile loop. Returns the block arguments ordered from outermost
+  /// to innermost or an empty vector if none are found.
+  SmallVector<BlockArgument> getTiedBBArgs(BlockArgument bbArg);
+
+  /// Returns the iteration argument of the outermost tile loop mapped to a
+  /// block argument `bbArg` of the innermost tile loop, or nullptr if the
+  /// tied block arguments do not span all tile loops.
+  OpOperand *getTiedIterArg(BlockArgument bbArg);
+
+  /// Returns true if `bbArg` has other uses than `sliceOp` and its
+  /// dependencies. Only if there are no other uses may the producer output
+  /// iteration argument be reused to pass the producer result after fusion.
+  bool hasOtherUses(BlockArgument bbArg, tensor::ExtractSliceOp sliceOp);
+
+  /// The root operation; replaced by the tiled operation after every tiling.
+  LinalgOp rootOp;
+  /// The tile loops ordered from outermost to innermost.
+  SmallVector<scf::ForOp> loopOps;
+  /// The root operation loop dimension tiled by each loop of `loopOps`.
+  SmallVector<int64_t> loopDims;
+};
+
+/// Tiles `consumerOp` and fuses its dependencies if possible. Uses the
+/// `tileSizes` and `tileInterchange` parameters to control the tiling. First
+/// tiles the outer parallel loops and fuses the producers of the output
+/// operands, then tiles the remaining loops and fuses the producers of the
+/// input operands. Expects `tileSizes` and `tileInterchange` to have the same
+/// size. Fails if tiling fails; individual fusion failures are ignored.
+FailureOr<TileLoopNest>
+tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
+                             ArrayRef<int64_t> tileSizes,
+                             ArrayRef<int64_t> tileInterchange);
+
 //===----------------------------------------------------------------------===//
 // Distribution utilities
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index ebc6443d683c..bd4487b62253 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   ElementwiseOpFusion.cpp
   ElementwiseToLinalg.cpp
   Fusion.cpp
+  FusionOnTensors.cpp
   Generalization.cpp
   Hoisting.cpp
   InlineScalarOperands.cpp

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
new file mode 100644
index 000000000000..d3069ee40ca7
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -0,0 +1,481 @@
+//===- FusionOnTensors.cpp - Implementation of linalg Fusion on tensors ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements linalg fusion on tensors.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+using namespace linalg;
+
+//===----------------------------------------------------------------------===//
+// StructuredOp specific helpers.
+//===----------------------------------------------------------------------===//
+
+/// Maps every consumer loop iteration to the producer loop iterations that
+/// access the same producer result element, that is:
+///   inverse(producerIndexingMap).compose(consumerIndexingMap).
+/// Only projected permutations can be inverted; for any other producer
+/// indexing map the result is none.
+static Optional<AffineMap>
+getConsumerToProducerLoopsMap(AffineMap producerIndexingMap,
+                              AffineMap consumerIndexingMap) {
+  assert(consumerIndexingMap.getNumResults() ==
+             producerIndexingMap.getNumResults() &&
+         "expect the number of indexing map results to match");
+  if (producerIndexingMap.isProjectedPermutation())
+    return inverseAndBroadcastProjectedPermuation(producerIndexingMap)
+        .compose(consumerIndexingMap);
+  return None;
+}
+
+/// Returns the producer result slice dimensions tiled by the tile loop nest.
+/// `tiledLoopDims` are the consumer loop dimensions tiled by the loop nest.
+/// Returns an empty vector if `getConsumerToProducerLoopsMap` returns none or
+/// if none of the producer result dimensions is tiled.
+// TODO: replace by Fourier-Motzkin and/or compute starting from consumer.
+static SmallVector<int64_t> getTiledSliceDims(OpResult producerResult,
+                                              OpOperand *consumerOperand,
+                                              ArrayRef<int64_t> tiledLoopDims) {
+  LinalgOp consumerOp = consumerOperand->getOwner();
+  LinalgOp producerOp = producerResult.getOwner();
+  OpOperand *opOperand =
+      producerOp.getOutputOperand(producerResult.getResultNumber());
+
+  // Compute the `consumerToProducerLoopsMap` and exit if the computation fails.
+  AffineMap producerIndexingMap = producerOp.getTiedIndexingMap(opOperand);
+  Optional<AffineMap> consumerToProducerLoopsMap =
+      getConsumerToProducerLoopsMap(
+          producerIndexingMap, consumerOp.getTiedIndexingMap(consumerOperand));
+  if (!consumerToProducerLoopsMap.hasValue())
+    return {};
+
+  // Compute the set of tiled producer loops: a producer loop is tiled if its
+  // expression in terms of the consumer loops depends on a tiled consumer loop.
+  DenseSet<int64_t> tiledProducerLoops;
+  for (auto en : enumerate(consumerToProducerLoopsMap->getResults())) {
+    for (int64_t dim : tiledLoopDims) {
+      if (en.value().isFunctionOfDim(dim))
+        tiledProducerLoops.insert(en.index());
+    }
+  }
+
+  // Compute the slice dimensions for the tiled producer loops.
+  SmallVector<int64_t> tiledSliceDims;
+  for (auto en : enumerate(producerIndexingMap.getResults())) {
+    auto dimExpr = en.value().dyn_cast<AffineDimExpr>();
+    if (dimExpr && tiledProducerLoops.count(dimExpr.getPosition()) != 0)
+      tiledSliceDims.push_back(en.index());
+  }
+  return tiledSliceDims;
+}
+
+/// Returns the producer fused in place of `sliceOp`. Tile the producer operands
+/// along the `tiledSliceDims` and clone the producer. Consider the case of
+/// fusion of an output tensor:
+/// ```
+/// %1 = producer ins(...) outs(%0)
+/// %2 = consumer ins(...) outs(%1)
+/// ```
+/// When consumer is tiled, %1 appears in the loop iter_args:
+/// ```
+/// %1 = producer ins(...) outs(%0)
+/// %2 = scf.for ... iter_args(%1) .. (%bbarg) {
+///   %t1 = tensor.extract_slice %bbarg[..]
+///   %t2 = consumer ins(...) outs(%t1)
+///   %r = tensor.insert_slice %t2, %bbarg[...]
+/// }
+/// ```
+/// Fusing %1 into the loop requires updating iter_args(%1) to iter_args(%0):
+/// ```
+/// %2 = scf.for ... iter_args(%0) .. (%bbarg) {
+///   %t0 = tensor.extract_slice %bbarg[..]
+///   %t1 = producer ins(...) outs(%t0)
+///   %t2 = consumer ins(...) outs(%t1)
+///   %r = tensor.insert_slice %t2, %bbarg[...]
+/// }
+/// ```
+/// This transformation is only valid if %bbarg is exclusively used by the
+/// output ExtractSliceOp / InsertSliceOp pair, which is checked by the
+/// `fuseProducer` method.
+/// A non-null `iterArg` is the outermost tile loop iteration argument that
+/// carries the producer result; it is rewired to the producer output operand
+/// as shown above and the clone reuses the `sliceOp` result as its output.
+/// TODO: instead of check and failure, insert new iter_args each time a
+/// producer is fused into a consumer and fold away unused iter_args.
+static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
+                                 tensor::ExtractSliceOp sliceOp,
+                                 ArrayRef<int64_t> tiledSliceDims,
+                                 OpOperand *iterArg) {
+  // Clone the producer after `sliceOp` since the slice may be reused to pass in
+  // the producer result.
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPointAfter(sliceOp);
+
+  // Get the producer.
+  LinalgOp producerOp = producerResult.getOwner();
+  Location loc = producerOp.getLoc();
+
+  // Obtain the `producerOp` loop bounds and the `sliceOp` ranges.
+  SmallVector<Value> producerLoopBounds;
+  transform(producerOp.createLoopRanges(b, loc),
+            std::back_inserter(producerLoopBounds),
+            [](Range range) { return range.size; });
+  SmallVector<Range> sliceOpRanges = sliceOp.getOrCreateRanges(b, loc);
+
+  // Get the producer result indexing map.
+  AffineMap producerIndexingMap = producerOp.getTiedIndexingMap(
+      producerOp.getOutputOperand(producerResult.getResultNumber()));
+
+  // Tile the producer operands given the `sliceOp` ranges. Iterate the
+  // `tiledSliceDims` and store the tile offset and size for the tiled slice
+  // dimension. Assumes the mapping from slice dimensions to producer loops is a
+  // permutation. Untiled producer loops keep a zero tile size.
+  auto zero = b.create<ConstantIndexOp>(loc, 0);
+  SmallVector<Value> tileIvs(producerOp.getNumLoops(), nullptr);
+  SmallVector<Value> tileSizes(producerOp.getNumLoops(), zero);
+  for (int64_t tiledSliceDim : tiledSliceDims) {
+    AffineExpr result = producerIndexingMap.getResults()[tiledSliceDim];
+    assert(result.isa<AffineDimExpr>() &&
+           "expect producer indexing map is a projected permutation");
+    int64_t tiledProducerLoop = result.cast<AffineDimExpr>().getPosition();
+    tileIvs[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].offset;
+    tileSizes[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].size;
+  }
+  // Keep the tile offsets of the tiled producer loops only.
+  erase_value(tileIvs, nullptr);
+  SmallVector<Value> tiledOperands = producerOp.getInputAndOutputOperands();
+  tiledOperands = makeTiledShapes(b, loc, producerOp, tiledOperands, tileIvs,
+                                  tileSizes, producerLoopBounds);
+
+  // Output fusion has to update the iteration arguments of the tile loop nest.
+  // In particular, the iteration argument of the outermost tile loop needs to
+  // be set to the producer output instead of the producer result and `clonedOp`
+  // shall use the existing `sliceOp` result instead of the tiled producer
+  // output operand.
+  if (iterArg) {
+    OpOperand *outputOperand =
+        producerOp.getOutputOperand(producerResult.getResultNumber());
+    iterArg->set(outputOperand->get());
+    tiledOperands[outputOperand->getOperandNumber()] = sliceOp.getResult();
+  }
+
+  // Clone the producer using the tiled producer operands.
+  TypeRange resultTypes = ValueRange(tiledOperands)
+                              .take_back(producerOp.getNumOutputs())
+                              .getTypes();
+  LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands);
+
+  return clonedOp;
+}
+
+//===----------------------------------------------------------------------===//
+// TileLoopNest specific helpers.
+//===----------------------------------------------------------------------===//
+
+/// A tile loop nest is empty as long as no tile loop has been recorded.
+bool TileLoopNest::isEmpty() { return loopOps.size() == 0; }
+
+/// Checks the tile loop nest invariants: one tile dimension per tile loop,
+/// `rootOp` nested directly in the innermost tile loop, and directly nested
+/// tile loops. An empty nest is reported as invalid since it has no innermost
+/// tile loop.
+bool TileLoopNest::isValid() {
+  // Bail out before accessing `loopOps.back()` on an empty nest.
+  if (loopOps.empty())
+    return false;
+
+  // Check if the number of `loopOps` and `loopDims` match.
+  if (loopOps.size() != loopDims.size())
+    return false;
+
+  // Check if the innermost tile loop is the parent of `rootOp`.
+  if (rootOp->getParentOp() != loopOps.back())
+    return false;
+
+  // Check if the tile loops are directly nested, i.e. every loop is the parent
+  // of the next one.
+  return std::adjacent_find(loopOps.begin(), loopOps.end(),
+                            [](Operation *op1, Operation *op2) {
+                              return op1 != op2->getParentOp();
+                            }) == loopOps.end();
+}
+
+/// Follows the chain of iteration arguments from the innermost tile loop block
+/// argument `bbArg` outwards. Returns the tied block arguments ordered from
+/// outermost to innermost, or an empty vector if the chain does not pass
+/// through a block argument of every tile loop.
+SmallVector<BlockArgument> TileLoopNest::getTiedBBArgs(BlockArgument bbArg) {
+  assert(bbArg && "expect the block argument to be non-zero");
+  SmallVector<BlockArgument> bbArgs;
+
+  // Search all tile loop block arguments from inner to outer.
+  for (auto tileLoop : reverse(loopOps)) {
+    // Stop if the chain leaves the tile loop nest: the iteration argument of
+    // the previous (inner) loop is not a block argument, or it is not owned by
+    // the current loop. Checking for null avoids dereferencing the result of a
+    // failed `dyn_cast` below.
+    if (!bbArg || bbArg.getOwner()->getParentOp() != tileLoop)
+      return {};
+    bbArgs.push_back(bbArg);
+    OpOperand *iterArg = &tileLoop.getOpOperandForRegionIterArg(bbArg);
+    bbArg = iterArg->get().dyn_cast<BlockArgument>();
+  }
+
+  // Reverse the block arguments to order them from outer to inner.
+  return {bbArgs.rbegin(), bbArgs.rend()};
+}
+
+/// Returns the outermost tile loop iteration argument tied to the innermost
+/// tile loop block argument `bbArg`, or nullptr if the tied block arguments do
+/// not span all tile loops.
+OpOperand *TileLoopNest::getTiedIterArg(BlockArgument bbArg) {
+  // Search all block arguments and return the matching iteration argument.
+  SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg);
+  // An empty chain (only possible for an empty tile loop nest) has no
+  // outermost block argument; guard against calling `front()` on it.
+  if (bbArgs.empty() || bbArgs.size() != loopOps.size())
+    return nullptr;
+  return &loopOps.front().getOpOperandForRegionIterArg(bbArgs.front());
+}
+
+/// Returns true if the innermost tile loop block argument `bbArg` has uses
+/// other than `sliceOp`, InsertSliceOps writing into it, and DimOps, or if one
+/// of the tied outer block arguments has more than one use. `fuseProducer`
+/// only rewires the iteration argument when this returns false, so every
+/// unexpected use has to be reported as true.
+bool TileLoopNest::hasOtherUses(BlockArgument bbArg,
+                                tensor::ExtractSliceOp sliceOp) {
+  // Check the innermost block argument is either used by the ExtractSliceOp
+  // `sliceOp`, by an InsertSliceOp that writes into it, or by a DimOp. Handle
+  // other uses conservatively by reporting them as other uses.
+  for (Operation *op : bbArg.getUsers()) {
+    if (isa<tensor::DimOp>(op))
+      continue;
+    if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
+      if (extractSliceOp != sliceOp)
+        return true;
+      continue;
+    }
+    if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(op)) {
+      // Using `bbArg` as the inserted source instead of the destination reads
+      // the producer result and thus is an other use.
+      // TODO: also verify the inserted value is computed from `sliceOp`.
+      if (insertSliceOp.dest() != bbArg)
+        return true;
+      continue;
+    }
+    return true;
+  }
+
+  // Check the block arguments, except for the innermost one, have one use.
+  SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg);
+  return !all_of(bbArgs, [&](BlockArgument tiedBBArg) {
+    return tiedBBArg.hasOneUse() || tiedBBArg == bbArgs.back();
+  });
+}
+
+/// Tiles `rootOp` with `tileSizes` (indexed by loop dimension) and the loop
+/// order `tileInterchange` (identity if empty), replaces the uses of the old
+/// root operation, and appends the new tile loops and their tiled dimensions
+/// to the nest. Succeeds without tiling if all tile sizes are zero.
+LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
+                                       ArrayRef<int64_t> tileSizes,
+                                       ArrayRef<int64_t> tileInterchange) {
+  assert((tileInterchange.empty() ||
+          tileInterchange.size() == tileSizes.size()) &&
+         "expect the number of tile sizes and interchange dims to match");
+
+  // Exit if all tile sizes are zero.
+  if (tileSizes.size() == static_cast<size_t>(count(tileSizes, 0)))
+    return success();
+
+  // Tile the root operation.
+  LinalgTilingOptions tilingOptions;
+  tilingOptions = tilingOptions
+                      .setInterchange(SmallVector<unsigned>(
+                          tileInterchange.begin(), tileInterchange.end()))
+                      .setTileSizes(tileSizes)
+                      .setLoopType(LinalgTilingLoopType::Loops);
+  Optional<TiledLinalgOp> tiledRootOp = tileLinalgOp(b, rootOp, tilingOptions);
+
+  // Replace all uses of the root operation.
+  if (!tiledRootOp.hasValue())
+    return failure();
+  rootOp->replaceAllUsesWith(tiledRootOp->tensorResults);
+
+  // Update the root operation and append the loops and tile loop dimensions.
+  // The tiling is expected to emit one loop per dimension with a non-zero tile
+  // size, ordered by the interchange. Since `tileSizes` is indexed by loop
+  // dimension, walk the interchange (or the identity if it is empty) to record
+  // the dimension tiled by every new loop; indexing the interchange by
+  // dimension is only correct for self-inverse interchanges.
+  rootOp = tiledRootOp->op;
+  loopOps.append(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
+  SmallVector<int64_t> loopOrder(tileInterchange.begin(),
+                                 tileInterchange.end());
+  if (loopOrder.empty())
+    loopOrder = llvm::to_vector<6>(llvm::seq<int64_t>(0, tileSizes.size()));
+  for (int64_t dim : loopOrder) {
+    // Copy only the tiled loop dimensions with non-zero tile size.
+    if (tileSizes[dim] != 0)
+      loopDims.push_back(dim);
+  }
+  assert(isValid() && "expect tile loop nest to be valid after tiling");
+
+  return success();
+}
+
+/// Fuses the LinalgOp producing the ExtractSliceOp source of `rootOpOperand`
+/// into the tile loop nest, either directly or through the tile loop
+/// iteration arguments. Returns the fused producer or failure.
+FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
+                                               OpOperand *rootOpOperand) {
+  // Fusion needs a non-empty loop nest that satisfies its invariants.
+  if (isEmpty() || !isValid())
+    return failure();
+
+  // The fused operand has to be defined by an ExtractSliceOp.
+  auto sliceOp = rootOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+  if (!sliceOp)
+    return failure();
+
+  // The slice has to live next to the root operation, and the operand has to
+  // belong to the root operation.
+  bool sliceNextToRoot = sliceOp->getParentOp() == rootOp->getParentOp();
+  if (!sliceNextToRoot || rootOpOperand->getOwner() != rootOp)
+    return failure();
+
+  // Resolve the producer result. A slice of a tile loop block argument is
+  // traced back to the outermost iteration argument, which may only be
+  // rewritten if the block argument has no other uses.
+  Value source = sliceOp.source();
+  OpOperand *iterArg = nullptr;
+  OpResult producerResult = source.dyn_cast<OpResult>();
+  if (auto bbArg = source.dyn_cast<BlockArgument>()) {
+    iterArg = getTiedIterArg(bbArg);
+    if (iterArg == nullptr || hasOtherUses(bbArg, sliceOp))
+      return failure();
+    producerResult = iterArg->get().dyn_cast<OpResult>();
+  }
+  if (!producerResult)
+    return failure();
+  auto producerOp = dyn_cast<LinalgOp>(producerResult.getOwner());
+  // TODO: support producers that have index semantics.
+  if (!producerOp || producerOp.hasIndexSemantics())
+    return failure();
+
+  // Determine the producer result dimensions sliced by the tile loops.
+  SmallVector<int64_t> tiledSliceDims =
+      getTiledSliceDims(producerResult, rootOpOperand, loopDims);
+  if (tiledSliceDims.empty())
+    return failure();
+
+  // Materialize the tiled producer right after `sliceOp`.
+  LinalgOp clonedOp =
+      getTiledProducer(b, producerResult, sliceOp, tiledSliceDims, iterArg);
+
+  // Bridge a type mismatch between the fused result and the consumer operand
+  // with a cast that canonicalization may later fold.
+  Value newResult = clonedOp->getResult(producerResult.getResultNumber());
+  Type consumerOperandType = rootOpOperand->get().getType();
+  if (newResult.getType() != consumerOperandType) {
+    OpBuilder::InsertionGuard guard(b);
+    b.setInsertionPointAfter(clonedOp);
+    newResult = b.create<tensor::CastOp>(producerResult.getLoc(),
+                                         consumerOperandType, newResult);
+  }
+
+  // Redirect the `sliceOp` users to the fused result, except `clonedOp` which
+  // may consume the slice as its output operand.
+  sliceOp.getResult().replaceAllUsesExcept(newResult, clonedOp);
+  return clonedOp;
+}
+
+//===----------------------------------------------------------------------===//
+// Tile and fuse entry-points.
+//===----------------------------------------------------------------------===//
+
+/// Tiles `consumerOp` in two steps: first the outer parallel loops (in
+/// interchanged order) followed by output operand fusion, then the remaining
+/// loops followed by input operand fusion. `tileSizes` is indexed by loop
+/// dimension and `tileInterchange` gives the tile loop order. Fails if tiling
+/// fails; individual fusion failures are ignored.
+FailureOr<TileLoopNest>
+mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
+                                           ArrayRef<int64_t> tileSizes,
+                                           ArrayRef<int64_t> tileInterchange) {
+  assert(tileSizes.size() == tileInterchange.size() &&
+         "expect the number of tile sizes and interchange dims to match");
+
+  // Create an empty tile loop nest.
+  TileLoopNest tileLoopNest(consumerOp);
+
+  // Search the number of outer parallel loops to separate them from possible
+  // inner reduction dimensions. `iterTypes` is permuted into tile loop order,
+  // so `split` counts interchanged positions, not loop dimensions.
+  SmallVector<StringAttr> iterTypes =
+      llvm::to_vector<6>(consumerOp.iterator_types().getAsRange<StringAttr>());
+  applyPermutationToVector(
+      iterTypes,
+      SmallVector<unsigned>(tileInterchange.begin(), tileInterchange.end()));
+  auto *it = find_if(iterTypes, [&](StringAttr iterType) {
+    return !isParallelIterator(iterType);
+  });
+  int64_t split = std::distance(iterTypes.begin(), it);
+
+  // Split the tile sizes into the outer part tiling the loop dimensions at the
+  // first `split` interchanged positions and the inner part tiling all other
+  // dimensions. Map positions back to dimensions via `tileInterchange` since
+  // `tileSizes` is indexed by dimension; otherwise a reduction dimension could
+  // be tiled before the output operands are fused.
+  SmallVector<int64_t> outerTileSizes(tileSizes.size(), 0);
+  SmallVector<int64_t> innerTileSizes(tileSizes.begin(), tileSizes.end());
+  for (int64_t dim : tileInterchange.take_front(split)) {
+    outerTileSizes[dim] = tileSizes[dim];
+    innerTileSizes[dim] = 0;
+  }
+
+  // Tile the outer parallel loops and fuse the output operands.
+  if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange)))
+    return failure();
+  for (OpOperand *opOperand : tileLoopNest.getRootOp().getOutputOperands())
+    (void)tileLoopNest.fuseProducer(b, opOperand);
+
+  // Tile the remaining loops and fuse the input operands.
+  if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange)))
+    return failure();
+  for (OpOperand *opOperand : tileLoopNest.getRootOp().getInputOperands())
+    (void)tileLoopNest.fuseProducer(b, opOperand);
+
+  return tileLoopNest;
+}
+
+namespace {
+struct LinalgTileAndFuseTensorOps
+    : public LinalgTileAndFuseTensorOpsBase<LinalgTileAndFuseTensorOps> {
+
+  /// Prints `message` to stderr and marks the pass as failed.
+  void notifyFailure(StringRef message) {
+    llvm::errs() << " - LinalgTileAndFuseTensorOps: " << message << "\n";
+    signalPassFailure();
+  }
+
+  void runOnFunction() override {
+    FuncOp funcOp = getFunction();
+    OpBuilder b(funcOp.getContext());
+
+    // Heuristic to find a good operation to tile and start fusion. Walk all
+    // operations and select the one with the maximal backward slice of fusion
+    // candidates.
+    LinalgOp rootOp = nullptr;
+    int64_t numFusionCandidates = -1;
+    funcOp.walk([&](LinalgOp linalgOp) {
+      SetVector<Operation *> backwardSlice;
+      getBackwardSlice(linalgOp, &backwardSlice);
+      int64_t backwardSliceSize = count_if(
+          backwardSlice, [](Operation *op) { return isa<LinalgOp>(op); });
+      if (backwardSliceSize > numFusionCandidates) {
+        rootOp = linalgOp;
+        numFusionCandidates = backwardSliceSize;
+      }
+    });
+    if (!rootOp)
+      return notifyFailure("expect to find a root operation");
+
+    // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
+    unsigned numLoops = rootOp.getNumLoops();
+    if (tileSizes.size() < numLoops)
+      return notifyFailure("expect #tile sizes >= #loops");
+
+    // Check `tileInterchange` contains no entries or as many as `tileSizes`.
+    if (!tileInterchange.empty() &&
+        tileInterchange.size() != tileSizes.size()) {
+      return notifyFailure(
+          "expect the number of tile sizes and interchange dims to match");
+    }
+
+    // Copy the `tileSizes` and `tileInterchange` prefixes needed to tile
+    // `rootOp` or use the identity interchange if `tileInterchange` is empty.
+    // The identity has to cover the `rootOp` loops only, since `tileSizes` may
+    // contain more entries than `rootOp` has loops.
+    SmallVector<int64_t> rootTileSizes(tileSizes.begin(),
+                                       tileSizes.begin() + numLoops);
+    SmallVector<int64_t> rootInterchange =
+        tileInterchange.empty()
+            ? llvm::to_vector<6>(llvm::seq<int64_t>(0, numLoops))
+            : SmallVector<int64_t>(tileInterchange.begin(),
+                                   tileInterchange.begin() + numLoops);
+
+    // As a tiling can only tile a loop dimension once, `rootInterchange` has to
+    // be a permutation of the `rootOp` loop dimensions. Reject out-of-range
+    // entries before building the map; they are no valid loop dimensions.
+    bool inRange = llvm::all_of(rootInterchange, [&](int64_t dim) {
+      return dim >= 0 && dim < static_cast<int64_t>(numLoops);
+    });
+    if (!inRange)
+      return notifyFailure(
+          "expect the tile interchange permutes the root loops");
+    SmallVector<AffineExpr> rootInterchangeExprs;
+    transform(rootInterchange, std::back_inserter(rootInterchangeExprs),
+              [&](int64_t dim) { return b.getAffineDimExpr(dim); });
+    AffineMap rootInterchangeMap = AffineMap::get(
+        numLoops, 0, rootInterchangeExprs, funcOp.getContext());
+    if (!rootInterchangeMap.isPermutation())
+      return notifyFailure(
+          "expect the tile interchange permutes the root loops");
+
+    // Tile `rootOp` and fuse its producers.
+    if (failed(tileConsumerAndFuseProducers(b, rootOp, rootTileSizes,
+                                            rootInterchange)))
+      return notifyFailure("tileConsumerAndFuseProducers failed unexpectedly");
+  }
+};
+} // namespace
+
+/// Creates an instance of the LinalgTileAndFuseTensorOps pass; the tile sizes
+/// and the tile interchange are configured through the pass options.
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgTileAndFuseTensorOpsPass() {
+  return std::make_unique<LinalgTileAndFuseTensorOps>();
+}

diff  --git a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
new file mode 100644
index 000000000000..4ead2391e4ef
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
@@ -0,0 +1,190 @@
+// RUN: mlir-opt %s -linalg-tile-and-fuse-tensor-ops="tile-sizes=5,4,7 tile-interchange=1,0,2" -cse -split-input-file | FileCheck %s
+
+// fuse_input: the linalg.fill that produces the matmul LHS (%0) should be
+// recomputed inside the innermost tile loop on a slice of %arg0. Going by
+// the expected affine.min maps, IV1 walks the 24-row dimension (tile size 5)
+// and IV2 walks the 12-wide reduction dimension (tile size 7); MAP2/MAP3
+// clamp the slice sizes on the boundary tiles.
+//  CHECK-DAG:  #[[MAP0:.*]] = affine_map<(d0) -> (5, -d0 + 24)>
+//  CHECK-DAG:  #[[MAP1:.*]] = affine_map<(d0) -> (7, -d0 + 12)>
+//  CHECK-DAG:  #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 24)>
+//  CHECK-DAG:  #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 12)>
+
+//      CHECK:  fuse_input
+// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
+builtin.func @fuse_input(%arg0: tensor<24x12xf32>,
+                         %arg1: tensor<12x25xf32>,
+                         %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  // NOTE(review): the index constants below are not referenced by any op in
+  // this function; they look like leftovers and could likely be removed.
+  %c0 = constant 0 : index
+  %c12 = constant 12 : index
+  %c25 = constant 25 : index
+  %c24 = constant 24 : index
+  %c4 = constant 4 : index
+  %cst = constant 0.000000e+00 : f32
+  %0 = linalg.fill(%cst, %arg0) : f32, tensor<24x12xf32> -> tensor<24x12xf32>
+
+  //      CHECK:  scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  //      CHECK:    scf.for %[[IV1:[0-9a-zA-Z]*]] =
+  //      CHECK:      %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]])
+  //      CHECK:      scf.for %[[IV2:[0-9a-zA-Z]*]] =
+  //      CHECK:        %[[TS2:.*]] = affine.min #[[MAP1]](%[[IV2]])
+
+  // Tile both input operand dimensions; the slice offsets are (IV1, IV2).
+  //      CHECK:        %[[UB1:.*]] = affine.min #[[MAP2]](%[[TS1]], %[[IV1]])
+  //      CHECK:        %[[UB2:.*]] = affine.min #[[MAP3]](%[[TS2]], %[[IV2]])
+  //      CHECK:        %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+  // CHECK-SAME:                                          %[[IV1]], %[[IV2]]
+  // CHECK-SAME:                                          %[[UB1]], %[[UB2]]
+  //      CHECK:        %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
+  //      CHECK:        %{{.*}} = linalg.matmul ins(%[[T1]]
+  %1 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  return %1 : tensor<24x25xf32>
+}
+
+// -----
+
+// fuse_output: the linalg.fill producing the matmul init tensor (%0) reaches
+// the tile loops through their iteration arguments. After fusion the loop
+// nest should iterate on %arg2 instead of %0, apply the fill to the current
+// output tile inside the second loop, and pass the filled tile as the
+// iter_args of the innermost loop that the matmul accumulates into.
+//  CHECK-DAG:  #[[MAP0:.*]] = affine_map<(d0) -> (5, -d0 + 24)>
+//  CHECK-DAG:  #[[MAP1:.*]] = affine_map<(d0) -> (4, -d0 + 25)>
+
+//      CHECK:  fuse_output
+// CHECK-SAME:    %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32>
+builtin.func @fuse_output(%arg0: tensor<24x12xf32>,
+                          %arg1: tensor<12x25xf32>,
+                          %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  // NOTE(review): the index constants below are not referenced by any op in
+  // this function; they look like leftovers and could likely be removed.
+  %c0 = constant 0 : index
+  %c12 = constant 12 : index
+  %c25 = constant 25 : index
+  %c24 = constant 24 : index
+  %c4 = constant 4 : index
+  %cst = constant 0.000000e+00 : f32
+  %0 = linalg.fill(%cst, %arg2) : f32, tensor<24x25xf32> -> tensor<24x25xf32>
+
+  // The outermost tile loop iterates on %arg2 (the fill's destination) rather
+  // than on the fill result %0.
+  //      CHECK:  scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]]
+  //      CHECK:    scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]]
+  //      CHECK:      %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]])
+  //      CHECK:      %[[TS0:.*]] = affine.min #[[MAP1]](%[[IV0]])
+
+  // Tile both output operand dimensions and apply the fill to that tile.
+  //      CHECK:      %[[T0:.*]] = tensor.extract_slice %[[ARG4]]
+  // CHECK-SAME:                                        %[[IV1]], %[[IV0]]
+  // CHECK-SAME:                                        %[[TS1]], %[[TS0]]
+  //      CHECK:      %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
+  //      CHECK:        scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]]
+  //      CHECK:          %{{.*}} = linalg.matmul {{.*}} outs(%[[ARG5]]
+  %1 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%0 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  return %1 : tensor<24x25xf32>
+}
+
+// -----
+
+// fuse_reduction: the producer %0 accumulates %arg3 into %arg1 along d1, its
+// only "reduction" iterator. Only the parallel dimensions d0 and d2 map onto
+// consumer tile loops, so the fused slice of %arg3 keeps offset 0 and the
+// full extent 7 in the reduced dimension.
+//  CHECK-DAG:  #[[MAP0:.*]] = affine_map<(d0) -> (4, -d0 + 25)>
+//  CHECK-DAG:  #[[MAP1:.*]] = affine_map<(d0) -> (7, -d0 + 12)>
+//  CHECK-DAG:  #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 25)>
+//  CHECK-DAG:  #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 12)>
+// #map0 reads the 12x7x25 input with the identity map; #map1 writes the
+// 12x25 output and drops the reduced dimension d1.
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+
+//      CHECK:  fuse_reduction
+// CHECK-SAME:    %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xf32>
+// CHECK-SAME:    %[[ARG3:[0-9a-zA-Z]*]]: tensor<12x7x25xf32>
+builtin.func @fuse_reduction(%arg0: tensor<24x12xf32>,
+                             %arg1: tensor<12x25xf32>,
+                             %arg2: tensor<24x25xf32>,
+                             %arg3: tensor<12x7x25xf32>) -> tensor<24x25xf32> {
+  // NOTE(review): the index constants below are not referenced by any op in
+  // this function; they look like leftovers and could likely be removed.
+  %c0 = constant 0 : index
+  %c12 = constant 12 : index
+  %c25 = constant 25 : index
+  %c24 = constant 24 : index
+  %c4 = constant 4 : index
+  %0 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg3 : tensor<12x7x25xf32>) outs(%arg1 : tensor<12x25xf32>) {
+  ^bb0(%arg4: f32, %arg5: f32):  // no predecessors
+    %2 = addf %arg4, %arg5 : f32
+    linalg.yield %2 : f32
+  } -> tensor<12x25xf32>
+
+  //      CHECK:  scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  //      CHECK:    scf.for %[[IV1:[0-9a-zA-Z]*]] =
+  //      CHECK:      %[[TS0:.*]] = affine.min #[[MAP0]](%[[IV0]])
+  //      CHECK:      scf.for %[[IV2:[0-9a-zA-Z]*]] =
+  //      CHECK:        %[[TS2:.*]] = affine.min #[[MAP1]](%[[IV2]])
+  //      CHECK:        %[[UB2:.*]] = affine.min #[[MAP3]](%[[TS2]], %[[IV2]])
+  //      CHECK:        %[[UB0:.*]] = affine.min #[[MAP2]](%[[TS0]], %[[IV0]])
+
+  // Tile only the parallel dimensions; the reduction dimension is taken whole
+  // (offset 0, size 7).
+  //      CHECK:        %[[T0:.*]] = tensor.extract_slice %[[ARG3]]
+  // CHECK-SAME:                                          %[[IV2]], 0, %[[IV0]]
+  // CHECK-SAME:                                          %[[UB2]], 7, %[[UB0]]
+  //      CHECK:        %[[T1:.*]] = tensor.extract_slice %[[ARG1]]
+  // CHECK-SAME:                                          %[[IV2]], %[[IV0]]
+  // CHECK-SAME:                                          %[[UB2]], %[[UB0]]
+  //      CHECK:        %[[T2:.*]] = linalg.generic {{.*}} ins(%[[T0]] {{.*}} outs(%[[T1]]
+  //      CHECK:        %{{.*}} = linalg.matmul ins(%{{.*}}, %[[T2]]
+  %1 = linalg.matmul ins(%arg0, %0 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  return %1 : tensor<24x25xf32>
+}
+
+// -----
+
+// fuse_transposed: the producer %0 reads %arg3 through #map0, which maps
+// (d0, d1) to (d1, d0). Tile (IV1, IV2) of the 24x12 result therefore needs
+// tile (IV2, IV1) of the 12x24 input.
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+
+//      CHECK:  fuse_transposed
+// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
+// CHECK-SAME:    %[[ARG3:[0-9a-zA-Z]*]]: tensor<12x24xf32>
+builtin.func @fuse_transposed(%arg0: tensor<24x12xf32>,
+                              %arg1: tensor<12x25xf32>,
+                              %arg2: tensor<24x25xf32>,
+                              %arg3: tensor<12x24xf32>) -> tensor<24x25xf32> {
+  // NOTE(review): the index constants below are not referenced by any op in
+  // this function; they look like leftovers and could likely be removed.
+  %c0 = constant 0 : index
+  %c12 = constant 12 : index
+  %c25 = constant 25 : index
+  %c24 = constant 24 : index
+  %c4 = constant 4 : index
+  %0 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg3 : tensor<12x24xf32>) outs(%arg0 : tensor<24x12xf32>) {
+  ^bb0(%arg4: f32, %arg5: f32):  // no predecessors
+    %2 = addf %arg4, %arg5 : f32
+    linalg.yield %2 : f32
+  } -> tensor<24x12xf32>
+
+  //      CHECK:  scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  //      CHECK:    scf.for %[[IV1:[0-9a-zA-Z]*]] =
+  //      CHECK:      scf.for %[[IV2:[0-9a-zA-Z]*]] =
+
+  // Swap the input operand slice offsets due to the transposed indexing map.
+  //      CHECK:        %[[T0:.*]] = tensor.extract_slice %[[ARG3]]
+  // CHECK-SAME:                                          %[[IV2]], %[[IV1]]
+  //      CHECK:        %[[T1:.*]] = tensor.extract_slice %[[ARG0]]
+  // CHECK-SAME:                                          %[[IV1]], %[[IV2]]
+  //      CHECK:        %[[T2:.*]] = linalg.generic {{.*}} ins(%[[T0]] {{.*}} outs(%[[T1]]
+  //      CHECK:        %{{.*}} = linalg.matmul ins(%[[T2]]
+  %1 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  return %1 : tensor<24x25xf32>
+}
+
+// -----
+
+// fuse_input_and_output: both producers are fused into one loop nest. The
+// output fill (%1 on %arg2) is expected in the second tile loop, where its
+// result becomes the iter_args of the innermost loop. The input fill (%0 on
+// %arg0) is expected inside the innermost loop, next to the matmul.
+//      CHECK:  fuse_input_and_output
+// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
+// CHECK-SAME:    %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32>
+builtin.func @fuse_input_and_output(%arg0: tensor<24x12xf32>,
+                                    %arg1: tensor<12x25xf32>,
+                                    %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  // NOTE(review): the index constants below are not referenced by any op in
+  // this function; they look like leftovers and could likely be removed.
+  %c0 = constant 0 : index
+  %c12 = constant 12 : index
+  %c25 = constant 25 : index
+  %c24 = constant 24 : index
+  %c4 = constant 4 : index
+  %cst = constant 0.000000e+00 : f32
+  %0 = linalg.fill(%cst, %arg0) : f32, tensor<24x12xf32> -> tensor<24x12xf32>
+  %1 = linalg.fill(%cst, %arg2) : f32, tensor<24x25xf32> -> tensor<24x25xf32>
+
+  // Fuse each producer at the depth where its operand tile is known: the
+  // output tile depends on (IV1, IV0), the input tile on (IV1, IV2).
+  //      CHECK:  scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]]
+  //      CHECK:    scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]]
+  //      CHECK:      %[[T0:.*]] = tensor.extract_slice %[[ARG4]]
+  // CHECK-SAME:                                        %[[IV1]], %[[IV0]]
+  //      CHECK:      %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
+  //      CHECK:        scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]]
+  //      CHECK:          %[[T2:.*]] = tensor.extract_slice %[[ARG0]]
+  // CHECK-SAME:                                            %[[IV1]], %[[IV2]]
+  //      CHECK:          %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]])
+  //      CHECK:          %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[ARG5]]
+  %2 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%1 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  return %2 : tensor<24x25xf32>
+}


        


More information about the Mlir-commits mailing list