[Mlir-commits] [mlir] 6db928b - [mlir][linalg] Fusion on tensors.
Tobias Gysi
llvmlistbot at llvm.org
Mon Sep 20 07:46:06 PDT 2021
Author: Tobias Gysi
Date: 2021-09-20T14:45:34Z
New Revision: 6db928b8f31b17caf205eee9c95bb817e51a3f2c
URL: https://github.com/llvm/llvm-project/commit/6db928b8f31b17caf205eee9c95bb817e51a3f2c
DIFF: https://github.com/llvm/llvm-project/commit/6db928b8f31b17caf205eee9c95bb817e51a3f2c.diff
LOG: [mlir][linalg] Fusion on tensors.
Add a new version of fusion on tensors that supports the following scenarios:
- support input and output operand fusion
- fuse a producer result passed in via tile loop iteration arguments (update the tile loop iteration arguments)
- supports only linalg operations on tensors
- supports only scf::for
- cannot add an output to the tile loop nest
The LinalgTileAndFuseOnTensors pass tiles the root operation and fuses its producers.
Reviewed By: nicolasvasilache, mravishankar
Differential Revision: https://reviews.llvm.org/D109766
Added:
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 12b55a153d16..063bf71d0523 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -77,6 +77,9 @@ std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass();
/// work on primitive types, if possible.
std::unique_ptr<Pass> createLinalgDetensorizePass();
+/// Create a pass to tile a LinalgOp and fuse its producers.
+std::unique_ptr<OperationPass<FuncOp>> createLinalgTileAndFuseTensorOpsPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index ecde91ff5120..acd18ff7977f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -243,4 +243,18 @@ def LinalgDetensorize : FunctionPass<"linalg-detensorize"> {
}];
}
+// Function pass that tiles a LinalgOp and fuses its producers. Both options
+// are comma-separated lists of 64-bit integers that may be given zero or more
+// times on the command line.
+def LinalgTileAndFuseTensorOps
+ : FunctionPass<"linalg-tile-and-fuse-tensor-ops"> {
+ let summary = "Tile a LinalgOp and fuse its producers.";
+ let constructor = "mlir::createLinalgTileAndFuseTensorOpsPass()";
+ let options = [
+ ListOption<"tileSizes", "tile-sizes", "int64_t", "Tile sizes",
+ "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
+ ListOption<"tileInterchange", "tile-interchange", "int64_t",
+ "Tile loop interchange",
+ "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
+ ];
+ let dependentDialects = ["linalg::LinalgDialect", "scf::SCFDialect"];
+}
+
#endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 8776b7404542..a8d4edf6c072 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -172,6 +172,64 @@ Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
OpResult producerOpResult,
OpOperand &consumerOpOperand);
+//===----------------------------------------------------------------------===//
+// Fusion on tensor utilities
+//===----------------------------------------------------------------------===//
+
+/// A class to manage the tile loop nest specific information.
+class TileLoopNest {
+public:
+ TileLoopNest(LinalgOp rootOp) : rootOp(rootOp) {}
+
+ /// Tile the root operation using the given `tileSizes` and `tileInterchange`.
+ LogicalResult tileRootOp(OpBuilder &b, ArrayRef<int64_t> tileSizes,
+ ArrayRef<int64_t> tileInterchange);
+
+ /// Fuse the producer of `rootOpOperand` into the tile loop nest. Returns the
+ /// fused producer or fails if fusion is not possible.
+ // TODO: add replace uses callback to support passes and patterns.
+ FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpOperand *rootOpOperand);
+
+ /// Returns the tiled root operation.
+ LinalgOp getRootOp() { return rootOp; }
+
+private:
+ /// Returns true if the tile loop nest has no tile loops.
+ bool isEmpty();
+
+ /// Returns true if the tile loop nest invariants are satisfied:
+ /// - The number of tile loop operations and dimensions match.
+ /// - The innermost tile loop is the parent of the root operation.
+ /// - The tile loops are directly nested.
+ // TODO: relax to support additional control flow, e.g., IfOp.
+ bool isValid();
+
+ /// Searches the block arguments tied to a block argument `bbArg` of the
+ /// innermost tile loop. Returns the block arguments from outermost to
+ /// innermost or an empty vector if none are found.
+ SmallVector<BlockArgument> getTiedBBArgs(BlockArgument bbArg);
+
+ /// Returns the iteration argument of the outermost tile loop mapped to a
+ /// block argument `bbArg` of the innermost tile loop.
+ OpOperand *getTiedIterArg(BlockArgument bbArg);
+
+ /// Returns true if `bbArg` has other uses than `sliceOp` and its
+ /// dependencies. Only if there are no other uses, the producer output
+ /// iteration argument may be reused to pass the producer result after fusion.
+ bool hasOtherUses(BlockArgument bbArg, tensor::ExtractSliceOp sliceOp);
+
+ /// The root operation; replaced by the tiled operation once tiled.
+ LinalgOp rootOp;
+ /// The tile loops created so far.
+ SmallVector<scf::ForOp> loopOps;
+ /// The root operation loop dimensions tiled by `loopOps`.
+ SmallVector<int64_t> loopDims;
+};
+
+/// Tiles `consumerOp` and fuses its dependencies if possible. Uses the
+/// `tileSizes` and `tileInterchange` parameters to control the tiling.
+FailureOr<TileLoopNest>
+tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
+ ArrayRef<int64_t> tileSizes,
+ ArrayRef<int64_t> tileInterchange);
+
//===----------------------------------------------------------------------===//
// Distribution utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index ebc6443d683c..bd4487b62253 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
ElementwiseOpFusion.cpp
ElementwiseToLinalg.cpp
Fusion.cpp
+ FusionOnTensors.cpp
Generalization.cpp
Hoisting.cpp
InlineScalarOperands.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
new file mode 100644
index 000000000000..d3069ee40ca7
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -0,0 +1,481 @@
+//===- FusionOnTensors.cpp - Implementation of linalg Fusion --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements linalg fusion on tensors.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+using namespace linalg;
+
+//===----------------------------------------------------------------------===//
+// StructuredOp specific helpers.
+//===----------------------------------------------------------------------===//
+
+/// Computes the map from the consumer loop iterations to the producer loop
+/// iterations that access the same producer result element:
+///   consumerToProducerLoops =
+///     inverse(producerIndexingMap).compose(consumerIndexingMap).
+/// Returns none if `producerIndexingMap` is not a projected permutation and
+/// thus cannot be inverted.
+static Optional<AffineMap>
+getConsumerToProducerLoopsMap(AffineMap producerIndexingMap,
+                              AffineMap consumerIndexingMap) {
+  assert(consumerIndexingMap.getNumResults() ==
+             producerIndexingMap.getNumResults() &&
+         "expect the number of indexing map results to match");
+  // Only a projected permutation can be inverted (and broadcast) below.
+  const bool isInvertible = producerIndexingMap.isProjectedPermutation();
+  if (isInvertible) {
+    return inverseAndBroadcastProjectedPermuation(producerIndexingMap)
+        .compose(consumerIndexingMap);
+  }
+  return None;
+}
+
+/// Returns the producer result slice dimensions tiled by the tile loop nest or
+/// an empty vector if `getConsumerToProducerLoopsMap` returns none.
+/// `producerResult` is the producer result accessed via `consumerOperand` and
+/// `tiledLoopDims` are the consumer loop dimensions tiled by the loop nest.
+/// The helper is only used within this file and thus has internal linkage.
+// TODO: replace by Fourier-Motzkin and/or compute starting from consumer.
+static SmallVector<int64_t> getTiledSliceDims(OpResult producerResult,
+                                              OpOperand *consumerOperand,
+                                              ArrayRef<int64_t> tiledLoopDims) {
+  LinalgOp consumerOp = consumerOperand->getOwner();
+  LinalgOp producerOp = producerResult.getOwner();
+  OpOperand *opOperand =
+      producerOp.getOutputOperand(producerResult.getResultNumber());
+
+  // Compute the `consumerToProducerLoopsMap` and exit if the computation fails.
+  AffineMap producerIndexingMap = producerOp.getTiedIndexingMap(opOperand);
+  Optional<AffineMap> consumerToProducerLoopsMap =
+      getConsumerToProducerLoopsMap(
+          producerIndexingMap, consumerOp.getTiedIndexingMap(consumerOperand));
+  if (!consumerToProducerLoopsMap.hasValue())
+    return {};
+
+  // Compute the set of tiled producer loops: a producer loop is tiled if its
+  // map result is a function of at least one tiled consumer loop dimension.
+  DenseSet<int64_t> tiledProducerLoops;
+  for (auto en : enumerate(consumerToProducerLoopsMap->getResults())) {
+    for (int64_t dim : tiledLoopDims) {
+      if (en.value().isFunctionOfDim(dim))
+        tiledProducerLoops.insert(en.index());
+    }
+  }
+
+  // Compute the slice dimensions for the tiled producer loops. Only results
+  // that are plain dimension expressions are considered.
+  SmallVector<int64_t> tiledSliceDims;
+  for (auto en : enumerate(producerIndexingMap.getResults())) {
+    auto dimExpr = en.value().dyn_cast<AffineDimExpr>();
+    if (dimExpr && tiledProducerLoops.count(dimExpr.getPosition()) != 0)
+      tiledSliceDims.push_back(en.index());
+  }
+  return tiledSliceDims;
+}
+
+/// Returns the producer fused in place of `sliceOp`. Tile the producer operands
+/// along the `tiledSliceDims` and clone the producer. Consider the case of
+/// fusion of an output tensor:
+/// ```
+/// %1 = producer ins(...) outs(%0)
+/// %2 = consumer ins(...) outs(%1)
+/// ```
+/// When consumer is tiled, %1 appears in the loop iter_args:
+/// ```
+/// %1 = producer ins(...) outs(%0)
+/// %2 = scf.for ... iter_args(%1) .. (%bbarg) {
+/// %t1 = tensor.extract_slice %bbarg[..]
+/// %t2 = consumer ins(...) outs(%t1)
+/// %r = tensor.insert_slice %t2, %bbarg[...]
+/// }
+/// ```
+/// Fusing %1 into the loop requires updating iter_args(%1) to iter_args(%0):
+/// ```
+/// %2 = scf.for ... iter_args(%0) .. (%bbarg) {
+/// %t0 = tensor.extract_slice %bbarg[..]
+/// %t1 = producer ins(...) outs(%t0)
+/// %t2 = consumer ins(...) outs(%t1)
+/// %r = tensor.insert_slice %t2, %bbarg[...]
+/// }
+/// ```
+/// This transformation is only valid if %bbarg is exclusively used by the
+/// output ExtractSliceOp / InsertSliceOp pair, which is checked by the
+/// `fuseProducer` method.
+/// TODO: instead of check and failure, insert new iter_args each time a
+/// producer is fused into a consumer and fold away unused iter_args.
+static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
+                                 tensor::ExtractSliceOp sliceOp,
+                                 ArrayRef<int64_t> tiledSliceDims,
+                                 OpOperand *iterArg) {
+  // Clone the producer after `sliceOp` since the slice may be reused to pass in
+  // the producer result. The guard restores the insertion point on return.
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPointAfter(sliceOp);
+
+  // Get the producer.
+  LinalgOp producerOp = producerResult.getOwner();
+  Location loc = producerOp.getLoc();
+
+  // Obtain the `producerOp` loop bounds and the `sliceOp` ranges.
+  SmallVector<Value> producerLoopBounds;
+  transform(producerOp.createLoopRanges(b, loc),
+            std::back_inserter(producerLoopBounds),
+            [](Range range) { return range.size; });
+  SmallVector<Range> sliceOpRanges = sliceOp.getOrCreateRanges(b, loc);
+
+  // Get the producer result indexing map.
+  AffineMap producerIndexingMap = producerOp.getTiedIndexingMap(
+      producerOp.getOutputOperand(producerResult.getResultNumber()));
+
+  // Tile the producer operands given the `sliceOp` ranges. Iterate the
+  // `tiledSliceDims` and store the tile offset and size for the tiled slice
+  // dimension. Assumes the mapping from slice dimensions to producer loops is a
+  // permutation. Loops that are not tiled keep a zero tile size and a null
+  // induction variable; the null entries are dropped before tiling.
+  auto zero = b.create<ConstantIndexOp>(loc, 0);
+  SmallVector<Value> tileIvs(producerOp.getNumLoops(), nullptr);
+  SmallVector<Value> tileSizes(producerOp.getNumLoops(), zero);
+  for (int64_t tiledSliceDim : tiledSliceDims) {
+    AffineExpr result = producerIndexingMap.getResults()[tiledSliceDim];
+    assert(result.isa<AffineDimExpr>() &&
+           "expect producer indexing map is a projected permutation");
+    int64_t tiledProducerLoop = result.cast<AffineDimExpr>().getPosition();
+    tileIvs[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].offset;
+    tileSizes[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].size;
+  }
+  erase_value(tileIvs, nullptr);
+  SmallVector<Value> tiledOperands = producerOp.getInputAndOutputOperands();
+  tiledOperands = makeTiledShapes(b, loc, producerOp, tiledOperands, tileIvs,
+                                  tileSizes, producerLoopBounds);
+
+  // Output fusion has to update the iteration arguments of the tile loop nest.
+  // In particular, the iteration argument of the outermost tile loop needs to
+  // be set to the producer output instead of the producer result and `clonedOp`
+  // shall use the existing `sliceOp` result instead of the tiled producer
+  // output operand.
+  if (iterArg) {
+    OpOperand *outputOperand =
+        producerOp.getOutputOperand(producerResult.getResultNumber());
+    iterArg->set(outputOperand->get());
+    tiledOperands[outputOperand->getOperandNumber()] = sliceOp.getResult();
+  }
+
+  // Clone the producer using the tiled producer operands. The result types are
+  // the types of the trailing (output) tiled operands.
+  TypeRange resultTypes = ValueRange(tiledOperands)
+                              .take_back(producerOp.getNumOutputs())
+                              .getTypes();
+  LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands);
+
+  return clonedOp;
+}
+
+//===----------------------------------------------------------------------===//
+// TileLoopNest specific helpers.
+//===----------------------------------------------------------------------===//
+
+bool TileLoopNest::isEmpty() {
+  // The nest is empty as long as no tile loop has been recorded.
+  return loopOps.size() == 0;
+}
+
+bool TileLoopNest::isValid() {
+  // Every tile loop in `loopOps` needs a matching entry in `loopDims`.
+  const bool sizesMatch = loopOps.size() == loopDims.size();
+  if (!sizesMatch)
+    return false;
+
+  // `rootOp` has to be nested directly in the innermost tile loop.
+  Operation *innermostLoop = loopOps.back();
+  if (rootOp->getParentOp() != innermostLoop)
+    return false;
+
+  // Each tile loop has to be the direct parent of the next inner tile loop.
+  for (size_t idx = 1, end = loopOps.size(); idx < end; ++idx) {
+    Operation *outerLoop = loopOps[idx - 1];
+    Operation *innerLoop = loopOps[idx];
+    if (outerLoop != innerLoop->getParentOp())
+      return false;
+  }
+  return true;
+}
+
+/// Follows `bbArg` of the innermost tile loop through the iteration arguments
+/// of all tile loops. Returns the tied block arguments ordered from the
+/// outermost to the innermost tile loop, or an empty vector if the chain does
+/// not pass through every tile loop.
+SmallVector<BlockArgument> TileLoopNest::getTiedBBArgs(BlockArgument bbArg) {
+  assert(bbArg && "expect the block argument to be non-zero");
+  SmallVector<BlockArgument> bbArgs;
+
+  // Search all tile loop block arguments from inner to outer.
+  for (auto tileLoop : reverse(loopOps)) {
+    // `bbArg` is null if the iteration argument of the previously visited
+    // (inner) tile loop is not a block argument. The chain is broken in this
+    // case and dereferencing `bbArg` would crash.
+    if (!bbArg || bbArg.getOwner()->getParentOp() != tileLoop)
+      return {};
+    bbArgs.push_back(bbArg);
+    OpOperand *iterArg = &tileLoop.getOpOperandForRegionIterArg(bbArg);
+    bbArg = iterArg->get().dyn_cast<BlockArgument>();
+  }
+
+  // Reverse the block arguments to order them from outer to inner.
+  return {bbArgs.rbegin(), bbArgs.rend()};
+}
+
+/// Returns the iteration argument operand of the outermost tile loop that is
+/// tied to `bbArg` of the innermost tile loop, or nullptr if `bbArg` is not
+/// tied through all tile loops.
+OpOperand *TileLoopNest::getTiedIterArg(BlockArgument bbArg) {
+  // Search all block arguments and return the matching iteration argument.
+  SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg);
+  // An empty result also covers an empty tile loop nest, for which the size
+  // comparison alone would pass and `loopOps.front()` would be invalid.
+  if (bbArgs.empty() || bbArgs.size() != loopOps.size())
+    return nullptr;
+  return &loopOps.front().getOpOperandForRegionIterArg(bbArgs.front());
+}
+
+// Inspects the users of the innermost block argument `bbArg` and the use
+// counts of the block arguments tied to it in the outer tile loops.
+// NOTE(review): the declaration documents "returns true if `bbArg` has other
+// uses", and `fuseProducer` bails out when this returns true. The early exits
+// in the loop below return false when they meet an unexpected user, which
+// looks inverted with respect to that contract - confirm before relying on it.
+bool TileLoopNest::hasOtherUses(BlockArgument bbArg,
+ tensor::ExtractSliceOp sliceOp) {
+ // Check the innermost block argument is either used by the ExtractSliceOp
+ // `sliceOp`, the matching InsertSliceOp, or by a DimOp. Handle other uses
+ // conservatively.
+ for (Operation *op : bbArg.getUsers()) {
+ if (!isa<tensor::DimOp, tensor::InsertSliceOp, tensor::ExtractSliceOp>(op))
+ return false;
+ if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
+ if (extractSliceOp != sliceOp)
+ return false;
+ }
+ if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(op)) {
+ // Collect the backward slice of the inserted value, restricted by the
+ // filter to LinalgOp and InsertSliceOp operations.
+ SetVector<Operation *> backwardSlice;
+ getBackwardSlice(insertSliceOp.source(), &backwardSlice,
+ [](Operation *op) {
+ return isa<LinalgOp, tensor::InsertSliceOp>(op);
+ });
+ if (backwardSlice.empty() || backwardSlice.front() != sliceOp)
+ return false;
+ }
+ }
+
+ // Check the block arguments, except for the innermost one, have one use.
+ // Returns true if any outer block argument has a use count other than one.
+ SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg);
+ return !all_of(bbArgs, [&](BlockArgument bbArg) {
+ return bbArg.hasOneUse() || bbArg == bbArgs.back();
+ });
+}
+
+/// Tiles the root operation with `tileSizes` and `tileInterchange` using
+/// scf::for loops, replaces the uses of the untiled root operation by the
+/// tiled results, and records the created tile loops and tiled loop
+/// dimensions. Succeeds without changes if all tile sizes are zero. Fails if
+/// tiling fails or if a non-zero tile size has no interchange entry.
+LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
+                                       ArrayRef<int64_t> tileSizes,
+                                       ArrayRef<int64_t> tileInterchange) {
+  // Exit if all tile sizes are zero.
+  if (tileSizes.size() == static_cast<size_t>(count(tileSizes, 0)))
+    return success();
+
+  // Every tiled loop reads its dimension from `tileInterchange` below. Reject
+  // a too short interchange vector before modifying the IR instead of reading
+  // out of bounds afterwards.
+  for (auto en : enumerate(tileSizes)) {
+    if (en.value() != 0 && en.index() >= tileInterchange.size())
+      return failure();
+  }
+
+  // Tile the root operation.
+  LinalgTilingOptions tilingOptions;
+  tilingOptions = tilingOptions
+                      .setInterchange(SmallVector<unsigned>(
+                          tileInterchange.begin(), tileInterchange.end()))
+                      .setTileSizes(tileSizes)
+                      .setLoopType(LinalgTilingLoopType::Loops);
+  Optional<TiledLinalgOp> tiledRootOp = tileLinalgOp(b, rootOp, tilingOptions);
+
+  // Replace all uses of the root operation.
+  if (!tiledRootOp.hasValue())
+    return failure();
+  rootOp->replaceAllUsesWith(tiledRootOp->tensorResults);
+
+  // Update the root operation and append the loops and tile loop dimensions.
+  rootOp = tiledRootOp->op;
+  loopOps.append(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
+  for (auto en : enumerate(tileSizes)) {
+    // Copy only the tiled loop dimensions with non-zero tile size.
+    if (en.value() == 0)
+      continue;
+    loopDims.push_back(tileInterchange[en.index()]);
+  }
+  assert(isValid() && "expect tile loop nest to be valid after tiling");
+
+  return success();
+}
+
+// Fuses the producer of `rootOpOperand` into the tile loop nest by cloning a
+// tiled version of the producer next to the ExtractSliceOp that defines the
+// operand. Returns the cloned producer, or failure if any of the checks below
+// rejects the fusion.
+FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
+ OpOperand *rootOpOperand) {
+ // Check the tile loop nest is non-empty and satisfies all invariants.
+ if (isEmpty() || !isValid())
+ return failure();
+
+ // Check `rootOpOperand` is defined by an ExtractSliceOp.
+ auto sliceOp = rootOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+ if (!sliceOp)
+ return failure();
+
+ // Check `tileLoopNest` tiles `sliceOp` and `rootOpOperand`.
+ if (sliceOp->getParentOp() != rootOp->getParentOp() ||
+ rootOpOperand->getOwner() != rootOp)
+ return failure();
+
+ // Check if the producer is a LinalgOp possibly passed by iteration argument.
+ // `iterArg` stays null unless the slice source is a block argument.
+ OpOperand *iterArg = nullptr;
+ auto producerResult = sliceOp.source().dyn_cast<OpResult>();
+ if (auto bbArg = sliceOp.source().dyn_cast<BlockArgument>()) {
+ iterArg = getTiedIterArg(bbArg);
+ // Check the iteration argument may be used to pass in the producer output.
+ if (!iterArg || hasOtherUses(bbArg, sliceOp))
+ return failure();
+ producerResult = iterArg->get().dyn_cast<OpResult>();
+ }
+ if (!producerResult || !isa<LinalgOp>(producerResult.getOwner()))
+ return failure();
+
+ // TODO: support producers that have index semantics.
+ if (cast<LinalgOp>(producerResult.getOwner()).hasIndexSemantics())
+ return failure();
+
+ // Compute the slice dimensions tiled by `tileLoopNest`.
+ SmallVector<int64_t> tiledSliceDims =
+ getTiledSliceDims(producerResult, rootOpOperand, loopDims);
+ if (tiledSliceDims.empty())
+ return failure();
+
+ // Tile the producer operands and clone the producer in place of `sliceOp`.
+ LinalgOp clonedOp =
+ getTiledProducer(b, producerResult, sliceOp, tiledSliceDims, iterArg);
+
+ // Cast the `clonedOp` result to bridge type mismatches before
+ // canonicalization.
+ Type consumerOperandType = rootOpOperand->get().getType();
+ Value newResult = clonedOp->getResult(producerResult.getResultNumber());
+ if (newResult.getType() != consumerOperandType) {
+ OpBuilder::InsertionGuard guard(b);
+ b.setInsertionPointAfter(clonedOp);
+ newResult = b.create<tensor::CastOp>(producerResult.getLoc(),
+ consumerOperandType, newResult);
+ }
+
+ // Replace the `sliceOp` uses except for the `clonedOp` output uses.
+ sliceOp.getResult().replaceAllUsesExcept(newResult, clonedOp);
+ return clonedOp;
+}
+
+//===----------------------------------------------------------------------===//
+// Tile and fuse entry-points.
+//===----------------------------------------------------------------------===//
+
+/// Tiles `consumerOp` in two steps and fuses its producers: the outer parallel
+/// loops are tiled first and the producers of the output operands are fused,
+/// then the remaining loops are tiled and the producers of the input operands
+/// are fused. Producers that cannot be fused are skipped. Returns the tile
+/// loop nest or failure if one of the tiling steps fails. Expects `tileSizes`
+/// and `tileInterchange` to have the same number of entries.
+FailureOr<TileLoopNest>
+mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
+                                           ArrayRef<int64_t> tileSizes,
+                                           ArrayRef<int64_t> tileInterchange) {
+  assert(tileSizes.size() == tileInterchange.size() &&
+         "expect the number of tile sizes and interchange dims to match");
+
+  // Create an empty tile loop nest.
+  TileLoopNest tileLoopNest(consumerOp);
+
+  // Search the number of outer parallel loops to separate them from possible
+  // inner reduction dimensions.
+  SmallVector<StringAttr> iterTypes =
+      llvm::to_vector<6>(consumerOp.iterator_types().getAsRange<StringAttr>());
+  applyPermutationToVector(
+      iterTypes,
+      SmallVector<unsigned>(tileInterchange.begin(), tileInterchange.end()));
+  auto *it = find_if(iterTypes, [&](StringAttr iterType) {
+    return !isParallelIterator(iterType);
+  });
+  int64_t split = std::distance(iterTypes.begin(), it);
+
+  // Tile the outer parallel loops and fuse the output operands. Fusion is
+  // best effort, hence the fusion result is ignored on purpose.
+  SmallVector<int64_t> outerTileSizes;
+  outerTileSizes.append(tileSizes.begin(), tileSizes.begin() + split);
+  outerTileSizes.append(tileSizes.size() - split, 0);
+  if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange)))
+    return failure();
+  for (OpOperand *opOperand : tileLoopNest.getRootOp().getOutputOperands())
+    (void)tileLoopNest.fuseProducer(b, opOperand);
+
+  // Tile the remaining loops and fuse the input operands.
+  SmallVector<int64_t> innerTileSizes;
+  innerTileSizes.append(split, 0);
+  innerTileSizes.append(tileSizes.begin() + split, tileSizes.end());
+  if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange)))
+    return failure();
+  for (OpOperand *opOperand : tileLoopNest.getRootOp().getInputOperands())
+    (void)tileLoopNest.fuseProducer(b, opOperand);
+
+  return tileLoopNest;
+}
+
+namespace {
+/// Pass that selects a root LinalgOp in a function, tiles it using the
+/// `tileSizes` and `tileInterchange` options, and fuses its producers.
+struct LinalgTileAndFuseTensorOps
+    : public LinalgTileAndFuseTensorOpsBase<LinalgTileAndFuseTensorOps> {
+
+  /// Prints `message` to the error stream and signals a pass failure.
+  void notifyFailure(StringRef message) {
+    llvm::errs() << " - LinalgTileAndFuseTensorOps: " << message << "\n";
+    signalPassFailure();
+  }
+
+  /// Finds the root operation, validates the pass options against its loop
+  /// dimensions, and runs `tileConsumerAndFuseProducers` on it.
+  void runOnFunction() override {
+    FuncOp funcOp = getFunction();
+    OpBuilder b(funcOp.getContext());
+
+    // Heuristic to find a good operation to tile and start fusion. Walk all
+    // operations and select the one with the maximal backward slice of fusion
+    // candidates.
+    LinalgOp rootOp = nullptr;
+    int64_t numFusionCandidates = -1;
+    funcOp.walk([&](LinalgOp linalgOp) {
+      SetVector<Operation *> backwardSlice;
+      getBackwardSlice(linalgOp, &backwardSlice);
+      int64_t backwardSliceSize = count_if(
+          backwardSlice, [](Operation *op) { return isa<LinalgOp>(op); });
+      if (backwardSliceSize > numFusionCandidates) {
+        rootOp = linalgOp;
+        numFusionCandidates = backwardSliceSize;
+      }
+    });
+    if (!rootOp)
+      return notifyFailure("expect to find a root operation");
+
+    // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
+    if (tileSizes.size() < rootOp.getNumLoops())
+      return notifyFailure("expect #tile sizes >= #loops");
+
+    // Check `tileInterchange` contains no entries or as many as `tileSizes`.
+    if (!tileInterchange.empty() &&
+        tileInterchange.size() != tileSizes.size()) {
+      return notifyFailure(
+          "expect the number of tile sizes and interchange dims to match");
+    }
+
+    // Copy the `tileSizes` and `tileInterchange` prefixes needed to tile
+    // `rootOp` or use the identity interchange if `tileInterchange` is empty.
+    // The identity interchange has one entry per `rootOp` loop so that it
+    // matches `rootTileSizes` even if more tile sizes than loops are given.
+    SmallVector<int64_t> rootTileSizes(
+        tileSizes.begin(), tileSizes.begin() + rootOp.getNumLoops());
+    SmallVector<int64_t> rootInterchange =
+        tileInterchange.empty()
+            ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
+            : SmallVector<int64_t>(tileInterchange.begin(),
+                                   tileInterchange.begin() +
+                                       rootOp.getNumLoops());
+
+    // As a tiling can only tile a loop dimension once, `rootInterchange` has to
+    // be a permutation of the `rootOp` loop dimensions.
+    SmallVector<AffineExpr> rootInterchangeExprs;
+    transform(rootInterchange, std::back_inserter(rootInterchangeExprs),
+              [&](int64_t dim) { return b.getAffineDimExpr(dim); });
+    AffineMap rootInterchangeMap = AffineMap::get(
+        rootOp.getNumLoops(), 0, rootInterchangeExprs, funcOp.getContext());
+    if (!rootInterchangeMap.isPermutation())
+      return notifyFailure(
+          "expect the tile interchange permutes the root loops");
+
+    // Tile `rootOp` and fuse its producers.
+    if (failed(tileConsumerAndFuseProducers(b, rootOp, rootTileSizes,
+                                            rootInterchange)))
+      return notifyFailure("tileConsumerAndFuseProducers failed unexpectedly");
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgTileAndFuseTensorOpsPass() {
+ return std::make_unique<LinalgTileAndFuseTensorOps>();
+}
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
new file mode 100644
index 000000000000..4ead2391e4ef
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
@@ -0,0 +1,190 @@
+// RUN: mlir-opt %s -linalg-tile-and-fuse-tensor-ops="tile-sizes=5,4,7 tile-interchange=1,0,2" -cse -split-input-file | FileCheck %s
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (5, -d0 + 24)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (7, -d0 + 12)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 24)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 12)>
+
+// CHECK: fuse_input
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
+builtin.func @fuse_input(%arg0: tensor<24x12xf32>,
+ %arg1: tensor<12x25xf32>,
+ %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+ %c0 = constant 0 : index
+ %c12 = constant 12 : index
+ %c25 = constant 25 : index
+ %c24 = constant 24 : index
+ %c4 = constant 4 : index
+ %cst = constant 0.000000e+00 : f32
+ %0 = linalg.fill(%cst, %arg0) : f32, tensor<24x12xf32> -> tensor<24x12xf32>
+
+ // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+ // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+ // CHECK: %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]])
+ // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] =
+ // CHECK: %[[TS2:.*]] = affine.min #[[MAP1]](%[[IV2]])
+
+ // Tile both input operand dimensions.
+ // CHECK: %[[UB1:.*]] = affine.min #[[MAP2]](%[[TS1]], %[[IV1]])
+ // CHECK: %[[UB2:.*]] = affine.min #[[MAP3]](%[[TS2]], %[[IV2]])
+ // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+ // CHECK-SAME: %[[IV1]], %[[IV2]]
+ // CHECK-SAME: %[[UB1]], %[[UB2]]
+ // CHECK: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
+ // CHECK: %{{.*}} = linalg.matmul ins(%[[T1]]
+ %1 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+ return %1 : tensor<24x25xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (5, -d0 + 24)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (4, -d0 + 25)>
+
+// CHECK: fuse_output
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32>
+builtin.func @fuse_output(%arg0: tensor<24x12xf32>,
+ %arg1: tensor<12x25xf32>,
+ %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+ %c0 = constant 0 : index
+ %c12 = constant 12 : index
+ %c25 = constant 25 : index
+ %c24 = constant 24 : index
+ %c4 = constant 4 : index
+ %cst = constant 0.000000e+00 : f32
+ %0 = linalg.fill(%cst, %arg2) : f32, tensor<24x25xf32> -> tensor<24x25xf32>
+
+ // Update the iteration argument of the outermost tile loop.
+ // CHECK: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]]
+ // CHECK: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]]
+ // CHECK: %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]])
+ // CHECK: %[[TS0:.*]] = affine.min #[[MAP1]](%[[IV0]])
+
+ // Tile both output operand dimensions.
+ // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG4]]
+ // CHECK-SAME: %[[IV1]], %[[IV0]]
+ // CHECK-SAME: %[[TS1]], %[[TS0]]
+ // CHECK: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
+ // CHECK: scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]]
+ // CHECK: %{{.*}} = linalg.matmul {{.*}} outs(%[[ARG5]]
+ %1 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%0 : tensor<24x25xf32>) -> tensor<24x25xf32>
+ return %1 : tensor<24x25xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (4, -d0 + 25)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (7, -d0 + 12)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 25)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 12)>
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+
+// CHECK: fuse_reduction
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xf32>
+// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]*]]: tensor<12x7x25xf32>
+builtin.func @fuse_reduction(%arg0: tensor<24x12xf32>,
+ %arg1: tensor<12x25xf32>,
+ %arg2: tensor<24x25xf32>,
+ %arg3: tensor<12x7x25xf32>) -> tensor<24x25xf32> {
+ %c0 = constant 0 : index
+ %c12 = constant 12 : index
+ %c25 = constant 25 : index
+ %c24 = constant 24 : index
+ %c4 = constant 4 : index
+ %0 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg3 : tensor<12x7x25xf32>) outs(%arg1 : tensor<12x25xf32>) {
+ ^bb0(%arg4: f32, %arg5: f32): // no predecessors
+ %2 = addf %arg4, %arg5 : f32
+ linalg.yield %2 : f32
+ } -> tensor<12x25xf32>
+
+ // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+ // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+ // CHECK: %[[TS0:.*]] = affine.min #[[MAP0]](%[[IV0]])
+ // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] =
+ // CHECK: %[[TS2:.*]] = affine.min #[[MAP1]](%[[IV2]])
+ // CHECK: %[[UB2:.*]] = affine.min #[[MAP3]](%[[TS2]], %[[IV2]])
+ // CHECK: %[[UB0:.*]] = affine.min #[[MAP2]](%[[TS0]], %[[IV0]])
+
+ // Tile only the parallel dimensions but not the reduction dimension.
+ // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG3]]
+ // CHECK-SAME: %[[IV2]], 0, %[[IV0]]
+ // CHECK-SAME: %[[UB2]], 7, %[[UB0]]
+ // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG1]]
+ // CHECK-SAME: %[[IV2]], %[[IV0]]
+ // CHECK-SAME: %[[UB2]], %[[UB0]]
+ // CHECK: %[[T2:.*]] = linalg.generic {{.*}} ins(%[[T0]] {{.*}} outs(%[[T1]]
+ // CHECK: %{{.*}} = linalg.matmul ins(%{{.*}}, %[[T2]]
+ %1 = linalg.matmul ins(%arg0, %0 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+ return %1 : tensor<24x25xf32>
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK: fuse_transposed
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
+// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]*]]: tensor<12x24xf32>
+builtin.func @fuse_transposed(%arg0: tensor<24x12xf32>,
+ %arg1: tensor<12x25xf32>,
+ %arg2: tensor<24x25xf32>,
+ %arg3: tensor<12x24xf32>) -> tensor<24x25xf32> {
+ %c0 = constant 0 : index
+ %c12 = constant 12 : index
+ %c25 = constant 25 : index
+ %c24 = constant 24 : index
+ %c4 = constant 4 : index
+ %0 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg3 : tensor<12x24xf32>) outs(%arg0 : tensor<24x12xf32>) {
+ ^bb0(%arg4: f32, %arg5: f32): // no predecessors
+ %2 = addf %arg4, %arg5 : f32
+ linalg.yield %2 : f32
+ } -> tensor<24x12xf32>
+
+ // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+ // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+ // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] =
+
+ // Swap the input operand slice offsets due to the transposed indexing map.
+ // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG3]]
+ // CHECK-SAME: %[[IV2]], %[[IV1]]
+ // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG0]]
+ // CHECK-SAME: %[[IV1]], %[[IV2]]
+ // CHECK: %[[T2:.*]] = linalg.generic {{.*}} ins(%[[T0]] {{.*}} outs(%[[T1]]
+ // CHECK: %{{.*}} = linalg.matmul ins(%[[T2]]
+ %1 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+ return %1 : tensor<24x25xf32>
+}
+
+// -----
+
+// CHECK: fuse_input_and_output
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32>
+builtin.func @fuse_input_and_output(%arg0: tensor<24x12xf32>,
+ %arg1: tensor<12x25xf32>,
+ %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+ %c0 = constant 0 : index
+ %c12 = constant 12 : index
+ %c25 = constant 25 : index
+ %c24 = constant 24 : index
+ %c4 = constant 4 : index
+ %cst = constant 0.000000e+00 : f32
+ %0 = linalg.fill(%cst, %arg0) : f32, tensor<24x12xf32> -> tensor<24x12xf32>
+ %1 = linalg.fill(%cst, %arg2) : f32, tensor<24x25xf32> -> tensor<24x25xf32>
+
+ // Fuse both producers to the appropriate tile loops.
+ // CHECK: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]]
+ // CHECK: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]]
+ // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG4]]
+ // CHECK-SAME: %[[IV1]], %[[IV0]]
+ // CHECK: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
+ // CHECK: scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]]
+ // CHECK: %[[T2:.*]] = tensor.extract_slice %[[ARG0]]
+ // CHECK-SAME: %[[IV1]], %[[IV2]]
+ // CHECK: %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]])
+ // CHECK: %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[ARG5]]
+ %2 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%1 : tensor<24x25xf32>) -> tensor<24x25xf32>
+ return %2 : tensor<24x25xf32>
+}
More information about the Mlir-commits
mailing list