[Mlir-commits] [mlir] c6ea095 - [mlir][Linalg] NFC : Move fusion on tensors to separate file.
llvmlistbot at llvm.org
Thu Oct 1 09:50:58 PDT 2020
Author: MaheshRavishankar
Date: 2020-10-01T09:50:37-07:00
New Revision: c6ea095b9756dff035aed27e7b5b44bf42d22462
URL: https://github.com/llvm/llvm-project/commit/c6ea095b9756dff035aed27e7b5b44bf42d22462
DIFF: https://github.com/llvm/llvm-project/commit/c6ea095b9756dff035aed27e7b5b44bf42d22462.diff
LOG: [mlir][Linalg] NFC : Move fusion on tensors to separate file.
Differential Revision: https://reviews.llvm.org/D88633
Added:
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
Modified:
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index a281aa55a44f..2b137175d174 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRLinalgTransforms
DropUnitDims.cpp
Fusion.cpp
+ FusionOnTensors.cpp
Hoisting.cpp
Interchange.cpp
Loops.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 8dadfe63e659..c964c2466d5c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -736,687 +736,12 @@ static void fuseLinalgOpsGreedily(FuncOp f) {
LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
}
-//====---------------------------------------------------------------------===//
-// Fusion on Tensor operation.
-//====---------------------------------------------------------------------===//
-
-namespace {
-
-/// Implementation of fusion of generic ops and indexed_generic ops.
-struct FuseGenericOpsOnTensors {
- static bool isFusible(LinalgOp producer, LinalgOp consumer,
- unsigned consumerIdx) {
- // Producer and consumer must have tensor semantics.
- if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
- return false;
-
- // Verify that
- // - the producer has all "parallel" iterator types.
- if (producer.getNumParallelLoops() != producer.getNumLoops())
- return false;
-
- // Get the consumer index map. The number of results of the consumer index
- // map must match the number of loops of the producer.
- AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx);
- if (consumerIndexMap.getNumResults() != producer.getNumLoops())
- return false;
-
- // Finally the index_map for the result must be invertible. For now just
- // verify it is a permutation.
- AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
- return producerResultIndexMap.isPermutation();
- }
-
- static LinalgOp fuse(LinalgOp producer, LinalgOp consumer,
- unsigned consumerIdx, PatternRewriter &rewriter,
- OperationFolder *folder = nullptr) {
- if (!isFusible(producer, consumer, consumerIdx))
- return nullptr;
-
- unsigned numFusedOperands = producer.getOperation()->getNumOperands() +
- consumer.getOperation()->getNumOperands() - 1;
-
- // Compute the fused operands list.
- SmallVector<Value, 2> fusedOperands;
- fusedOperands.reserve(numFusedOperands);
- auto consumerOperands = consumer.getOperation()->getOperands();
- auto producerOperands = producer.getOperation()->getOperands();
- fusedOperands.assign(consumerOperands.begin(),
- std::next(consumerOperands.begin(), consumerIdx));
- fusedOperands.append(producerOperands.begin(), producerOperands.end());
- fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1),
- consumerOperands.end());
-
- // Compute indexing_maps for the fused operation. The indexing_maps for the
- // operands of the consumer that aren't fused are the same. The
- // indexing_maps for the producer need to be computed based on the
- // indexing_map of the operand at consumerIdx in the consumer.
- SmallVector<Attribute, 4> fusedIndexMaps;
- auto consumerIndexMaps = consumer.indexing_maps();
- fusedIndexMaps.reserve(fusedOperands.size() +
- consumer.getOperation()->getNumResults());
- fusedIndexMaps.assign(consumerIndexMaps.begin(),
- std::next(consumerIndexMaps.begin(), consumerIdx));
- // Compute indexing maps for the producer args in the fused operation.
- computeProducerOperandIndex(
- producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps);
-
- // Append the indexing maps for the remaining consumer operands.
- fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1),
- consumerIndexMaps.end());
-
- // Generate the fused op.
- // Tensor-level fusion is only on ops without initTensors and outputBuffers.
- LinalgOp fusedOp;
- if (isa<GenericOp>(producer.getOperation()) &&
- isa<GenericOp>(consumer.getOperation())) {
- fusedOp =
- rewriter
- .create<GenericOp>(consumer.getLoc(),
- consumer.getOperation()->getResultTypes(),
- /*inputs=*/fusedOperands,
- /*outputBuffers=*/ValueRange{},
- /*initTensors=*/ValueRange{},
- rewriter.getArrayAttr(fusedIndexMaps),
- consumer.iterator_types(),
- /*doc=*/nullptr,
- /*library_call=*/nullptr,
- /*symbol_source=*/nullptr)
- .getOperation();
- } else {
- fusedOp =
- rewriter
- .create<IndexedGenericOp>(
- consumer.getLoc(), consumer.getOperation()->getResultTypes(),
- /*inputs=*/fusedOperands,
- /*outputBuffers=*/ValueRange{},
- /*initTensors=*/ValueRange{},
- rewriter.getArrayAttr(fusedIndexMaps),
- consumer.iterator_types(),
- /*doc=*/nullptr,
- /*library_call=*/nullptr,
- /*symbol_source=*/nullptr)
- .getOperation();
- }
-
- // Construct an AffineMap from consumer loops to producer loops.
- // consumer loop -> tensor index
- AffineMap consumerResultIndexMap =
- consumer.getInputIndexingMap(consumerIdx);
- // producer loop -> tensor index
- AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
- // tensor index -> producer loop
- AffineMap invProducerResultIndexMap =
- inversePermutation(producerResultIndexMap);
- assert(invProducerResultIndexMap &&
- "expected producer result indexig map to be invertible");
- // consumer loop -> producer loop
- AffineMap consumerToProducerLoopsMap =
- invProducerResultIndexMap.compose(consumerResultIndexMap);
-
- generateFusedRegion(rewriter, fusedOp, producer, consumer,
- consumerToProducerLoopsMap, consumerIdx,
- consumer.getNumLoops());
- return fusedOp;
- }
-
-private:
- /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
- /// the `producer` to use in the fused operation given the indexing map of the
- /// result of the producer in the consumer.
- static void computeProducerOperandIndex(
- LinalgOp producer, AffineMap fusedConsumerArgIndexMap,
- SmallVectorImpl<Attribute> &fusedOpIndexingMapAttrs) {
- // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
- // from consumer loop -> consumer arg tensor index/producer result tensor
- // index. The fused loop is the same as the consumer loop. For each producer arg
- // the indexing map to be computed is a map from consumer loop -> producer
- // arg tensor index.
-
- AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
- // producerResultIndexMap is a map from producer loop -> tensor index.
- // Compute the inverse to get map from tensor index -> producer loop.
- // The inverse is a map from producer result tensor index -> producer loop.
- AffineMap invProducerResultIndexMap =
- inversePermutation(producerResultIndexMap);
- assert(invProducerResultIndexMap &&
- "expected producer result indexig map to be invertible");
- for (unsigned argNum : llvm::seq<unsigned>(0, producer.getNumInputs())) {
- // argMap is a map from producer loop -> producer arg tensor index.
- AffineMap argMap = producer.getInputIndexingMap(argNum);
-
- // Compose argMap with invProducerResultIndexMap to get a map from
- // producer result tensor index -> producer arg tensor index.
- AffineMap t1 = argMap.compose(invProducerResultIndexMap);
-
- // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
- // consumer loop / fused loop -> producer arg tensor index.
- AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap);
- fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap));
- }
- }
-
- /// Generate the region of the fused operation. The region of the fused op
- /// must be empty.
- static void generateFusedRegion(PatternRewriter &rewriter, Operation *fusedOp,
- LinalgOp producer, LinalgOp consumer,
- AffineMap consumerToProducerLoopsMap,
- unsigned consumerIdx, unsigned nloops) {
- // Build the region of the fused op.
- Block &producerBlock = producer.getOperation()->getRegion(0).front();
- Block &consumerBlock = consumer.getOperation()->getRegion(0).front();
- Block *fusedBlock = new Block();
- fusedOp->getRegion(0).push_back(fusedBlock);
- BlockAndValueMapping mapper;
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(fusedBlock);
-
- // The block arguments are
- // [index_0, index_1, ... ,
- // consumer_operand_0, ... , consumer_operand_(`consumerIdx`-1),
- // producer_operand_0, ... , producer_operand_(n-1),
- // consumer_operand_(`consumerIdx`+1), ... , consumer_operand_(m-1)],
- // where n is the number of producer operands and m is the number of
- // consumer operands.
- // If both `numProducerIndices` and `numConsumerIndices` are zero, this is a
- // generic op. In this case, there are no indices in block arguments.
- unsigned numProducerIndices =
- isa<IndexedGenericOp>(producer.getOperation()) ? nloops : 0;
- unsigned numConsumerIndices =
- isa<IndexedGenericOp>(consumer.getOperation()) ? nloops : 0;
- // Firstly, add all the indices to the block arguments.
- for (unsigned i = 0, e = std::max(numProducerIndices, numConsumerIndices);
- i < e; ++i)
- fusedBlock->addArgument(rewriter.getIndexType());
- // Map the arguments for the unmodified args from the consumer.
- for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
- if (consumerArg.index() == consumerIdx + numConsumerIndices) {
- // Map the arguments for the args from the producer.
- for (auto producerArg : llvm::enumerate(producerBlock.getArguments())) {
- // If the producer is an indexed_generic op, map the indices from the
- // consumer loops to the producer loops (because the fusedOp is built
- // from the consumer's perspective).
- if (producerArg.index() < numProducerIndices) {
- auto newIndex = rewriter.create<mlir::AffineApplyOp>(
- producer.getLoc(),
- consumerToProducerLoopsMap.getSubMap(producerArg.index()),
- fusedBlock->getArguments().take_front(nloops));
- mapper.map(producerArg.value(), newIndex);
- } else {
- mapper.map(producerArg.value(),
- fusedBlock->addArgument(producerArg.value().getType()));
- }
- }
- continue;
- }
-
- // If the consumer is an indexed_generic op, map the indices to the block
- // arguments directly. Otherwise, add an argument of the same type and map to
- // it.
- if (consumerArg.index() < numConsumerIndices) {
- mapper.map(consumerArg.value(),
- fusedBlock->getArgument(consumerArg.index()));
- } else {
- mapper.map(consumerArg.value(),
- fusedBlock->addArgument(consumerArg.value().getType()));
- }
- }
-
- // Add operations from producer (except the yield operation) to the fused
- // op.
- for (auto &op : producerBlock.getOperations()) {
- if (auto yieldOp = dyn_cast<linalg::YieldOp>(op)) {
- // Lookup the value the yield operation is mapped to.
- Value yieldVal = yieldOp.getOperand(0);
- if (Value clonedVal = mapper.lookupOrNull(yieldVal))
- mapper.map(
- consumerBlock.getArgument(consumerIdx + numConsumerIndices),
- clonedVal);
- continue;
- }
- rewriter.clone(op, mapper);
- }
- for (auto &op : consumerBlock.getOperations())
- rewriter.clone(op, mapper);
- }
-};
-} // namespace
-
-/// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
-/// provided, given the shape of the source tensor that corresponds to the
- /// `sourceMap`. Note that this implicitly assumes that the tensor's dimensions
- /// are logically ordered "row-major".
-///
-/// For example:
-///
-/// %0 = op ... : tensor<?x?x4x5xf32>
-/// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
-///
-/// and reshape:
-/// %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
-/// affine_map<(i, j, k, l) -> (j, k, l)>] :
-/// tensor<?x?x4x5xf32> into tensor<?x?xf32>
-///
-/// would be rewritten into:
-/// %0 = op ... : tensor<?x?x4x5xf32>
-/// with output index_map
-/// `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
-static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
- ArrayRef<int64_t> sourceShape,
- ArrayRef<AffineMap> reassociationMaps) {
- SmallVector<AffineExpr, 4> resultExprs;
- resultExprs.reserve(reassociationMaps.size());
- ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
- MLIRContext *context = sourceMap.getContext();
-
- // Compute the result exprs based on the reassociation maps.
- for (AffineMap map : reassociationMaps) {
- ArrayRef<AffineExpr> collapsedDims = map.getResults();
- // Assume that they are in-order and contiguous (already checked in
- // verifier).
- assert(!collapsedDims.empty());
- unsigned startDim =
- collapsedDims.front().cast<AffineDimExpr>().getPosition();
- AffineExpr linearizedExpr = makeCanonicalStridedLayoutExpr(
- sourceShape.slice(startDim, collapsedDims.size()),
- sourceExprs.slice(startDim, collapsedDims.size()), context);
- resultExprs.push_back(linearizedExpr);
- }
- return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(),
- resultExprs, context);
-}
-
- /// Checks if the `reshapeOp` can be fused with its consumer (if `asProducer` is
-/// true) or its producer (if `asProducer` is false) given the indexing map at
-/// its use.
-static bool isTensorReshapeOpFusible(TensorReshapeOp reshapeOp,
- AffineMap useIndexMap, bool asProducer) {
- RankedTensorType returnType = reshapeOp.getResultType();
- RankedTensorType operandType = reshapeOp.getSrcType();
- // Reshape is fusible with its consumer (i.e. reshape as a producer) when its
- // operand is of lesser rank than the result. Fusing when the operand has a
- // higher rank would require mods and divs in the indexing maps of the fused
- // op, which would make it non-invertible. Similarly, reshape is fused with
- // its producer (i.e. reshape as a consumer) only if the return type has a
- // lesser rank.
- if ((asProducer && returnType.getRank() < operandType.getRank()) ||
- (!asProducer && operandType.getRank() < returnType.getRank()))
- return false;
- return useIndexMap.isIdentity();
-}
-
- /// Based on the type of `op`, create a linalg op of the same type, i.e. if `op`
- /// is a linalg.generic operation, then create a `linalg.generic` operation with
-/// the given `args`. Expects `op` to be `linalg.generic` or
-/// `linalg.indexed_generic`.
-template <typename... Args>
-static LinalgOp createLinalgOpOfSameType(LinalgOp op, PatternRewriter &rewriter,
- Args... args) {
- if (isa<GenericOp>(op.getOperation()))
- return cast<LinalgOp>(rewriter.create<GenericOp>(args...).getOperation());
- if (isa<IndexedGenericOp>(op.getOperation()))
- return cast<LinalgOp>(
- rewriter.create<IndexedGenericOp>(args...).getOperation());
- llvm_unreachable(
- "expected only linalg.generic or linalg.indexed_generic ops");
- return nullptr;
-}
-
-namespace {
-
-/// Implementation of fusion on tensor ops when producer is a TensorReshapeOp.
-struct FuseTensorReshapeOpAsProducer {
- static bool isFusible(TensorReshapeOp producer, LinalgOp consumer,
- unsigned consumerIdx) {
- return isa<GenericOp, IndexedGenericOp>(consumer.getOperation()) &&
- consumer.hasTensorSemantics() &&
- isTensorReshapeOpFusible(producer,
- consumer.getInputIndexingMap(consumerIdx),
- /*asProducer=*/true);
- }
-
- static LinalgOp fuse(TensorReshapeOp producer, LinalgOp consumer,
- unsigned consumerIdx, PatternRewriter &rewriter,
- OperationFolder *folder = nullptr) {
- if (producer.src().getDefiningOp<ConstantOp>())
- return nullptr;
-
- if (!isFusible(producer, consumer, consumerIdx))
- return nullptr;
-
- // Compute the fused operands list.
- Operation *consumerOp = consumer.getOperation();
- SmallVector<Value, 2> fusedOperands(consumerOp->getOperands());
- fusedOperands[consumerIdx] = producer.src();
-
- // Compute indexing_maps for the fused operation. The indexing_maps for the
- // operands of the consumer that aren't fused are the same.
- SmallVector<AffineMap, 4> fusedIndexMaps =
- llvm::to_vector<4>(llvm::map_range(
- consumer.indexing_maps(), [](Attribute attr) -> AffineMap {
- return attr.cast<AffineMapAttr>().getValue();
- }));
-
- // Compute the indexing map to use for the operand of the producer.
- AffineMap modifiedMap = linearizeCollapsedDims(
- fusedIndexMaps[consumerIdx], producer.getResultType().getShape(),
- producer.getReassociationMaps());
- for (AffineExpr expr : modifiedMap.getResults()) {
- if (!expr.isPureAffine())
- return nullptr;
- }
- fusedIndexMaps[consumerIdx] = modifiedMap;
-
- // Further check that the resulting index maps can be fused and
- // inverted. Without this the resultant op is not legal.
- if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
- return nullptr;
-
- SmallVector<Attribute, 4> indexMapAttrs = llvm::to_vector<4>(
- llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
- return AffineMapAttr::get(map);
- }));
- LinalgOp fusedOp = createLinalgOpOfSameType(
- consumer, rewriter, rewriter.getUnknownLoc(),
- consumerOp->getResultTypes(),
- /*inputs=*/fusedOperands,
- /*outputBuffers=*/ValueRange{},
- /*initTensors=*/ValueRange{}, // no init tensors for now.
- rewriter.getArrayAttr(indexMapAttrs), consumer.iterator_types(),
- /*doc=*/nullptr,
- /*library_call=*/nullptr,
- /*symbol_source=*/nullptr);
- auto &fusedRegion = fusedOp.getOperation()->getRegion(0);
- rewriter.cloneRegionBefore(consumerOp->getRegion(0), fusedRegion,
- fusedRegion.begin());
- return fusedOp;
- }
-};
-
-/// Implementation of fusion on tensor ops when consumer is a TensorReshapeOp.
-struct FuseTensorReshapeOpAsConsumer {
- static bool isCollapsingAndFusible(LinalgOp producer,
- TensorReshapeOp consumer,
- unsigned consumerIdx) {
- return isa<GenericOp, IndexedGenericOp>(producer.getOperation()) &&
- producer.hasTensorSemantics() &&
- isTensorReshapeOpFusible(consumer, producer.getOutputIndexingMap(0),
- /*asProducer=*/false);
- }
-
- static LinalgOp fuseCollapsingCase(LinalgOp producer,
- TensorReshapeOp consumer,
- unsigned consumerIdx,
- PatternRewriter &rewriter) {
- // The indexing_maps for the operands of the fused operation are the same as
- // those for the operands of the producer.
- SmallVector<AffineMap, 4> fusedIndexMaps =
- llvm::to_vector<4>(llvm::map_range(
- producer.indexing_maps(), [](Attribute attr) -> AffineMap {
- return attr.cast<AffineMapAttr>().getValue();
- }));
- // Compute the indexing map to use for the result of the producer.
- AffineMap modifiedMap = linearizeCollapsedDims(
- producer.getOutputIndexingMap(0), consumer.getSrcType().getShape(),
- consumer.getReassociationMaps());
- for (AffineExpr expr : modifiedMap.getResults()) {
- if (!expr.isPureAffine())
- return nullptr;
- }
- fusedIndexMaps.back() = modifiedMap;
-
- // Further check that the resulting index maps can be fused and
- // inverted. Without this the resultant op is not legal.
- if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
- return nullptr;
-
- SmallVector<Attribute, 4> indexMapAttrs = llvm::to_vector<4>(
- llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
- return AffineMapAttr::get(map);
- }));
-
- Operation *producerOp = producer.getOperation();
- LinalgOp fusedOp = createLinalgOpOfSameType(
- producer, rewriter, rewriter.getUnknownLoc(), consumer.getResultType(),
- /*inputs=*/producerOp->getOperands(),
- /*outputBuffers=*/ValueRange{},
- /*initTensors=*/ValueRange{}, // no init tensors for now.
- rewriter.getArrayAttr(indexMapAttrs), producer.iterator_types(),
- /*doc=*/nullptr,
- /*library_call=*/nullptr,
- /*symbol_source=*/nullptr);
- auto &fusedRegion = fusedOp.getOperation()->getRegion(0);
- rewriter.cloneRegionBefore(producerOp->getRegion(0), fusedRegion,
- fusedRegion.begin());
- return fusedOp;
- }
-
- static bool isExpandingAndFusible(LinalgOp producer, TensorReshapeOp consumer,
- unsigned consumerIdx) {
- // Is fusible only if:
- // 1) The producer is a generic op.
- // 2) The producer has tensor semantics.
- // 3) The tensor reshape op is an expanding case.
- // 4) All the shapes are the same for the generic op.
- // 5) All the indexing maps in producer are identity.
- // 6) All the loops in producer are parallel loops.
- // 7) The producer has a single user.
- auto types = producer.getInputOutputShapedTypes();
- assert(!types.empty());
- return isa<GenericOp>(producer.getOperation()) &&
- producer.hasTensorSemantics() &&
- consumer.getSrcType().getRank() <
- consumer.getResultType().getRank() &&
- std::equal(types.begin() + 1, types.end(), types.begin()) &&
- llvm::all_of(producer.getIndexingMaps(),
- [](AffineMap map) { return map.isIdentity(); }) &&
- llvm::all_of(producer.iterator_types(),
- [](Attribute attr) {
- return attr.cast<StringAttr>().getValue() ==
- getParallelIteratorTypeName();
- }) &&
- producer.getOperation()->hasOneUse();
- }
-
- static LinalgOp fuseExpandingCase(LinalgOp producer, TensorReshapeOp consumer,
- unsigned consumerIdx,
- PatternRewriter &rewriter) {
- Location loc = producer.getLoc();
- auto dstShape = consumer.getResultType().cast<ShapedType>().getShape();
- SmallVector<Value, 4> args;
- for (auto arg : producer.getOperation()->getOperands()) {
- auto type = RankedTensorType::get(
- dstShape, arg.getType().cast<ShapedType>().getElementType());
- args.push_back(rewriter.createOrFold<linalg::TensorReshapeOp>(
- loc, type, arg, consumer.reassociation()));
- }
-
- SmallVector<Type, 4> resultTypes;
- for (auto t : producer.getOutputTensorTypes()) {
- Type type = RankedTensorType::get(dstShape,
- t.cast<ShapedType>().getElementType());
- resultTypes.push_back(type);
- }
-
- int rank = dstShape.size();
- auto genericOp = rewriter.create<linalg::GenericOp>(
- loc, resultTypes, /*inputs=*/args,
- /*outputBuffers=*/ValueRange{},
- /*initTensors=*/ValueRange{},
- SmallVector<AffineMap, 3>(args.size() + resultTypes.size(),
- rewriter.getMultiDimIdentityMap(rank)),
- SmallVector<StringRef, 3>(rank, getParallelIteratorTypeName()));
- Region ®ion = genericOp.getRegion();
- rewriter.cloneRegionBefore(producer.getOperation()->getRegion(0), region,
- region.begin());
- return cast<LinalgOp>(genericOp.getOperation());
- }
-
- static LinalgOp fuse(LinalgOp producer, TensorReshapeOp consumer,
- unsigned consumerIdx, PatternRewriter &rewriter,
- OperationFolder *folder = nullptr) {
- if (isCollapsingAndFusible(producer, consumer, consumerIdx))
- return fuseCollapsingCase(producer, consumer, consumerIdx, rewriter);
- if (isExpandingAndFusible(producer, consumer, consumerIdx))
- return fuseExpandingCase(producer, consumer, consumerIdx, rewriter);
- return nullptr;
- }
-};
-
-/// Implementation of fusion on tensor ops when producer is a splat constant.
-struct FuseConstantOpAsProducer {
- static bool isFusible(ConstantOp producer, LinalgOp consumer,
- unsigned consumerIdx) {
- return isa<GenericOp, IndexedGenericOp>(consumer.getOperation()) &&
- consumer.hasTensorSemantics() &&
- producer.getResult().getType().isa<RankedTensorType>() &&
- producer.value().cast<DenseElementsAttr>().isSplat();
- }
-
- static LinalgOp fuse(ConstantOp producer, LinalgOp consumer,
- unsigned consumerIdx, PatternRewriter &rewriter,
- OperationFolder *folder = nullptr) {
- if (!isFusible(producer, consumer, consumerIdx))
- return nullptr;
-
- // The indexing_maps for the operands of the fused operation are the same as
- // those for the operands of the consumer, without the indexing map at
- // consumerIdx.
- SmallVector<AffineMap, 4> fusedIndexMaps =
- llvm::to_vector<4>(llvm::map_range(
- consumer.indexing_maps(), [](Attribute attr) -> AffineMap {
- return attr.cast<AffineMapAttr>().getValue();
- }));
- fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), consumerIdx));
-
- // The operand list is the same as that of the consumer, with the operand at
- // the constant's index (consumerIdx) dropped.
- Operation *consumerOp = consumer.getOperation();
- SmallVector<Value, 4> fusedOperands(consumerOp->getOperands());
- fusedOperands.erase(std::next(fusedOperands.begin(), consumerIdx));
-
- // Create a constant scalar value from the splat constant.
- Value scalarConstant = rewriter.create<ConstantOp>(
- producer.getLoc(),
- producer.value().cast<DenseElementsAttr>().getSplatValue());
-
- LinalgOp fusedOp = createLinalgOpOfSameType(
- consumer, rewriter, rewriter.getUnknownLoc(),
- consumerOp->getResultTypes(),
- /*inputs=*/fusedOperands,
- /*outputBuffers=*/ValueRange{},
- /*initTensors=*/ValueRange{}, // no init tensors for now.
- rewriter.getAffineMapArrayAttr(fusedIndexMaps),
- consumer.iterator_types(),
- /*doc=*/nullptr,
- /*library_call=*/nullptr,
- /*symbol_source=*/nullptr);
-
- // Map the block argument corresponding to the replaced argument with the
- // scalar constant.
- Region &consumerRegion = consumerOp->getRegion(0);
- Block &entryBlock = *consumerRegion.begin();
- unsigned argIndex = entryBlock.getNumArguments() -
- consumerOp->getNumOperands() + consumerIdx;
- BlockAndValueMapping mapping;
- mapping.map(entryBlock.getArgument(argIndex), scalarConstant);
- Region &fusedRegion = fusedOp.getOperation()->getRegion(0);
- rewriter.cloneRegionBefore(consumerRegion, fusedRegion, fusedRegion.begin(),
- mapping);
- return fusedOp;
- }
-};
-} // namespace
-
-Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
- Operation *consumer,
- unsigned consumerIdx,
- OperationFolder *folder) {
- if (consumerIdx >= consumer->getNumOperands())
- return nullptr;
- Operation *producer = consumer->getOperand(consumerIdx).getDefiningOp();
- if (!producer || producer->getNumResults() != 1)
- return nullptr;
-
- // Fuse when consumer is GenericOp or IndexedGenericOp.
- if (isa<GenericOp, IndexedGenericOp>(consumer)) {
- if (isa<GenericOp, IndexedGenericOp>(producer))
- return FuseGenericOpsOnTensors::fuse(cast<LinalgOp>(producer),
- cast<LinalgOp>(consumer),
- consumerIdx, rewriter, folder);
- if (auto reshapeOpProducer = dyn_cast<TensorReshapeOp>(producer))
- return FuseTensorReshapeOpAsProducer::fuse(reshapeOpProducer,
- cast<LinalgOp>(consumer),
- consumerIdx, rewriter, folder);
- if (auto constantOpProducer = dyn_cast<ConstantOp>(producer))
- return FuseConstantOpAsProducer::fuse(constantOpProducer,
- cast<LinalgOp>(consumer),
- consumerIdx, rewriter, folder);
- return nullptr;
- }
-
- if (isa<GenericOp, IndexedGenericOp>(producer)) {
- // Fuse when consumer is a TensorReshapeOp.
- if (TensorReshapeOp reshapeOp = dyn_cast<TensorReshapeOp>(consumer)) {
- return FuseTensorReshapeOpAsConsumer::fuse(
- cast<LinalgOp>(producer), reshapeOp, consumerIdx, rewriter, folder);
- }
- }
-
- return nullptr;
-}
-
namespace {
-/// Patterns to fuse a generic op with the producers of its operands.
-template <typename LinalgOpTy>
-struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
- using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(LinalgOpTy op,
- PatternRewriter &rewriter) const override {
- // Find the first operand that is defined by another generic op on tensors.
- for (auto operandNum :
- llvm::seq<unsigned>(0, op.getOperation()->getNumOperands())) {
- Operation *producer =
- op.getOperation()->getOperand(operandNum).getDefiningOp();
- if (Operation *fusedOp = fuseTensorOps(rewriter, op, operandNum)) {
- rewriter.replaceOp(op, fusedOp->getResults());
- if (producer && llvm::all_of(producer->getResults(),
- [](Value val) { return val.use_empty(); }))
- rewriter.eraseOp(producer);
- return success();
- }
- }
- return failure();
- }
-};
-
-/// Pass that fuses generic ops on tensors. Used only for testing.
-struct FusionOfTensorOpsPass
- : public LinalgFusionOfTensorOpsBase<FusionOfTensorOpsPass> {
- void runOnOperation() override {
- OwningRewritePatternList patterns;
- Operation *op = getOperation();
- populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns);
- applyPatternsAndFoldGreedily(op->getRegions(), patterns);
- };
-};
-
struct LinalgFusionPass : public LinalgFusionBase<LinalgFusionPass> {
void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
};
} // namespace
-void mlir::populateLinalgTensorOpsFusionPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns) {
- patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
- FuseTensorOps<TensorReshapeOp>>(context);
-}
-
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
return std::make_unique<LinalgFusionPass>();
}
-
-std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
- return std::make_unique<FusionOfTensorOpsPass>();
-}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
new file mode 100644
index 000000000000..a62b1ada2c18
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -0,0 +1,698 @@
+//===- FusionOnTensors.cpp - Implementation of linalg fusion on tensors --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements fusion on tensor operations for the linalg dialect.
+//
+//===----------------------------------------------------------------------===//
+#include "PassDetail.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+
+/// Implementation of fusion of generic ops and indexed_generic ops.
+struct FuseGenericOpsOnTensors {
+ static bool isFusible(LinalgOp producer, LinalgOp consumer,
+ unsigned consumerIdx) {
+ // Producer and consumer must have tensor semantics.
+ if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
+ return false;
+
+ // Verify that
+ // - the producer has all "parallel" iterator types.
+ if (producer.getNumParallelLoops() != producer.getNumLoops())
+ return false;
+
+ // Get the consumer index map. The number of results of the consumer index
+ // map must match the number of loops of the producer.
+ AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx);
+ if (consumerIndexMap.getNumResults() != producer.getNumLoops())
+ return false;
+
+ // Finally the index_map for the result must be invertible. For now just
+ // verify it is a permutation.
+ AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
+ return producerResultIndexMap.isPermutation();
+ }
+
+ static LinalgOp fuse(LinalgOp producer, LinalgOp consumer,
+ unsigned consumerIdx, PatternRewriter &rewriter,
+ OperationFolder *folder = nullptr) {
+ if (!isFusible(producer, consumer, consumerIdx))
+ return nullptr;
+
+ unsigned numFusedOperands = producer.getOperation()->getNumOperands() +
+ consumer.getOperation()->getNumOperands() - 1;
+
+ // Compute the fused operands list.
+ SmallVector<Value, 2> fusedOperands;
+ fusedOperands.reserve(numFusedOperands);
+ auto consumerOperands = consumer.getOperation()->getOperands();
+ auto producerOperands = producer.getOperation()->getOperands();
+ fusedOperands.assign(consumerOperands.begin(),
+ std::next(consumerOperands.begin(), consumerIdx));
+ fusedOperands.append(producerOperands.begin(), producerOperands.end());
+ fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1),
+ consumerOperands.end());
+
+ // Compute indexing_maps for the fused operation. The indexing_maps for the
+ // operands of the consumer that aren't fused are the same. The
+ // indexing_maps for the producer need to be computed based on the
+ // indexing_map of the operand at consumerIdx in the consumer.
+ SmallVector<Attribute, 4> fusedIndexMaps;
+ auto consumerIndexMaps = consumer.indexing_maps();
+ fusedIndexMaps.reserve(fusedOperands.size() +
+ consumer.getOperation()->getNumResults());
+ fusedIndexMaps.assign(consumerIndexMaps.begin(),
+ std::next(consumerIndexMaps.begin(), consumerIdx));
+ // Compute indexing maps for the producer args in the fused operation.
+ computeProducerOperandIndex(
+ producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps);
+
+ // Append the indexing maps for the remaining consumer operands.
+ fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1),
+ consumerIndexMaps.end());
+
+ // Generate the fused op.
+ // Tensor-level fusion is only on ops without initTensors and outputBuffers.
+ LinalgOp fusedOp;
+ if (isa<GenericOp>(producer.getOperation()) &&
+ isa<GenericOp>(consumer.getOperation())) {
+ fusedOp =
+ rewriter
+ .create<GenericOp>(consumer.getLoc(),
+ consumer.getOperation()->getResultTypes(),
+ /*inputs=*/fusedOperands,
+ /*outputBuffers=*/ValueRange{},
+ /*initTensors=*/ValueRange{},
+ rewriter.getArrayAttr(fusedIndexMaps),
+ consumer.iterator_types(),
+ /*doc=*/nullptr,
+ /*library_call=*/nullptr,
+ /*symbol_source=*/nullptr)
+ .getOperation();
+ } else {
+ fusedOp =
+ rewriter
+ .create<IndexedGenericOp>(
+ consumer.getLoc(), consumer.getOperation()->getResultTypes(),
+ /*inputs=*/fusedOperands,
+ /*outputBuffers=*/ValueRange{},
+ /*initTensors=*/ValueRange{},
+ rewriter.getArrayAttr(fusedIndexMaps),
+ consumer.iterator_types(),
+ /*doc=*/nullptr,
+ /*library_call=*/nullptr,
+ /*symbol_source=*/nullptr)
+ .getOperation();
+ }
+
+ // Construct an AffineMap from consumer loops to producer loops.
+ // consumer loop -> tensor index
+ AffineMap consumerResultIndexMap =
+ consumer.getInputIndexingMap(consumerIdx);
+ // producer loop -> tensor index
+ AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
+ // tensor index -> producer loop
+ AffineMap invProducerResultIndexMap =
+ inversePermutation(producerResultIndexMap);
+ assert(invProducerResultIndexMap &&
+ "expected producer result indexig map to be invertible");
+ // consumer loop -> producer loop
+ AffineMap consumerToProducerLoopsMap =
+ invProducerResultIndexMap.compose(consumerResultIndexMap);
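+ // Illustrative example (hypothetical maps, not taken from the surrounding
+ // code): if the producer result map is (d0, d1) -> (d1, d0) and the consumer
+ // reads the producer result with the identity map (d0, d1) -> (d0, d1), then
+ // the inverse of the producer result map is (d0, d1) -> (d1, d0) and the
+ // composition maps consumer loops (d0, d1) to producer loops (d1, d0).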
+
+ generateFusedRegion(rewriter, fusedOp, producer, consumer,
+ consumerToProducerLoopsMap, consumerIdx,
+ consumer.getNumLoops());
+ return fusedOp;
+ }
+
+private:
+ /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
+ /// the `producer` to use in the fused operation given the indexing map of the
+ /// result of the producer in the consumer.
+ static void computeProducerOperandIndex(
+ LinalgOp producer, AffineMap fusedConsumerArgIndexMap,
+ SmallVectorImpl<Attribute> &fusedOpIndexingMapAttrs) {
+ // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
+ // from consumer loop -> consumer arg tensor index/producer result tensor
+ // index. The fused loop is the same as the consumer loop. For each producer arg
+ // the indexing map to be computed is a map from consumer loop -> producer
+ // arg tensor index.
+
+ AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
+ // producerResultIndexMap is a map from producer loop -> tensor index.
+ // Compute the inverse to get map from tensor index -> producer loop.
+ // The inverse is a map from producer result tensor index -> producer loop.
+ AffineMap invProducerResultIndexMap =
+ inversePermutation(producerResultIndexMap);
+ assert(invProducerResultIndexMap &&
+ "expected producer result indexig map to be invertible");
+ for (unsigned argNum : llvm::seq<unsigned>(0, producer.getNumInputs())) {
+ // argMap is a map from producer loop -> producer arg tensor index.
+ AffineMap argMap = producer.getInputIndexingMap(argNum);
+
+ // Compose argMap with invProducerResultIndexMap to get a map from
+ // producer result tensor index -> producer arg tensor index.
+ AffineMap t1 = argMap.compose(invProducerResultIndexMap);
+
+ // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
+ // consumer loop / fused loop -> producer arg tensor index.
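+ // Worked example (hypothetical maps, for exposition only): with producer
+ // result map (d0, d1) -> (d1, d0), producer arg map (d0, d1) -> (d1), and
+ // consumer arg map (d0, d1) -> (d0, d1):
+ //   inv(producer result map)            : (d0, d1) -> (d1, d0)
+ //   t1 = argMap o inv                   : (d0, d1) -> (d0)
+ //   indexingMap = t1 o consumerArgMap   : (d0, d1) -> (d0)
+ // i.e. in the fused op this producer argument is indexed by the first loop.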
+ AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap);
+ fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap));
+ }
+ }
+
+ /// Generate the region of the fused operation. The region of the fused op
+ /// must be empty.
+ static void generateFusedRegion(PatternRewriter &rewriter, Operation *fusedOp,
+ LinalgOp producer, LinalgOp consumer,
+ AffineMap consumerToProducerLoopsMap,
+ unsigned consumerIdx, unsigned nloops) {
+ // Build the region of the fused op.
+ Block &producerBlock = producer.getOperation()->getRegion(0).front();
+ Block &consumerBlock = consumer.getOperation()->getRegion(0).front();
+ Block *fusedBlock = new Block();
+ fusedOp->getRegion(0).push_back(fusedBlock);
+ BlockAndValueMapping mapper;
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(fusedBlock);
+
+ // The block arguments are
+ // [index_0, index_1, ... ,
+ // consumer_operand_0, ... , consumer_operand_(`consumerIdx`-1),
+ // producer_operand_0, ... , producer_operand_(n-1),
+ // consumer_operand_(`consumerIdx`+1), ... , consumer_operand_(m-1)],
+ // where n is the number of producer operands and m is the number of
+ // consumer operands.
+ // If both `numProducerIndices` and `numConsumerIndices` are zero, this is a
+ // generic op. In this case, there are no indices in block arguments.
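+ // Concrete illustration (hypothetical sizes): for a generic producer with 2
+ // operands, an indexed_generic consumer with 3 operands, nloops = 2 and
+ // consumerIdx = 1, the fused block arguments end up as
+ //   [idx0, idx1, consumer_0, producer_0, producer_1, consumer_2].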
+ unsigned numProducerIndices =
+ isa<IndexedGenericOp>(producer.getOperation()) ? nloops : 0;
+ unsigned numConsumerIndices =
+ isa<IndexedGenericOp>(consumer.getOperation()) ? nloops : 0;
+ // Firstly, add all the indices to the block arguments.
+ for (unsigned i = 0, e = std::max(numProducerIndices, numConsumerIndices);
+ i < e; ++i)
+ fusedBlock->addArgument(rewriter.getIndexType());
+ // Map the arguments for the unmodified args from the consumer.
+ for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
+ if (consumerArg.index() == consumerIdx + numConsumerIndices) {
+ // Map the arguments for the args from the producer.
+ for (auto producerArg : llvm::enumerate(producerBlock.getArguments())) {
+ // If the producer is an indexed_generic op, map the indices from the
+ // consumer loops to the producer loops (because the fusedOp is built
+ // from the consumer's perspective).
+ if (producerArg.index() < numProducerIndices) {
+ auto newIndex = rewriter.create<mlir::AffineApplyOp>(
+ producer.getLoc(),
+ consumerToProducerLoopsMap.getSubMap(producerArg.index()),
+ fusedBlock->getArguments().take_front(nloops));
+ mapper.map(producerArg.value(), newIndex);
+ } else {
+ mapper.map(producerArg.value(),
+ fusedBlock->addArgument(producerArg.value().getType()));
+ }
+ }
+ continue;
+ }
+
+ // If the consumer is an indexed_generic op, map the indices to the block
+ // arguments directly. Otherwise, add an argument of the same type and map to
+ // it.
+ if (consumerArg.index() < numConsumerIndices) {
+ mapper.map(consumerArg.value(),
+ fusedBlock->getArgument(consumerArg.index()));
+ } else {
+ mapper.map(consumerArg.value(),
+ fusedBlock->addArgument(consumerArg.value().getType()));
+ }
+ }
+
+ // Add operations from producer (except the yield operation) to the fused
+ // op.
+ for (auto &op : producerBlock.getOperations()) {
+ if (auto yieldOp = dyn_cast<linalg::YieldOp>(op)) {
+ // Lookup the value the yield operation is mapped to.
+ Value yieldVal = yieldOp.getOperand(0);
+ if (Value clonedVal = mapper.lookupOrNull(yieldVal))
+ mapper.map(
+ consumerBlock.getArgument(consumerIdx + numConsumerIndices),
+ clonedVal);
+ continue;
+ }
+ rewriter.clone(op, mapper);
+ }
+ for (auto &op : consumerBlock.getOperations())
+ rewriter.clone(op, mapper);
+ }
+};
+} // namespace
+
+/// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
+/// provided, given the shape of the source tensor that corresponds to the
+ /// `sourceMap`. Note that this implicitly assumes that the tensor's dimensions
+ /// are logically ordered "row-major".
+///
+/// For example:
+///
+/// %0 = op ... : tensor<?x?x4x5xf32>
+/// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
+///
+/// and reshape:
+/// %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+/// affine_map<(i, j, k, l) -> (j, k, l)>] :
+/// tensor<?x?x4x5xf32> into tensor<?x?xf32>
+///
+/// would be rewritten into:
+/// %0 = op ... : tensor<?x?x4x5xf32>
+/// with output index_map
+/// `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
+static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
+ ArrayRef<int64_t> sourceShape,
+ ArrayRef<AffineMap> reassociationMaps) {
+ SmallVector<AffineExpr, 4> resultExprs;
+ resultExprs.reserve(reassociationMaps.size());
+ ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
+ MLIRContext *context = sourceMap.getContext();
+
+ // Compute the result exprs based on the reassociation maps.
+ for (AffineMap map : reassociationMaps) {
+ ArrayRef<AffineExpr> collapsedDims = map.getResults();
+ // Assume that they are in-order and contiguous (already checked in
+ // verifier).
+ assert(!collapsedDims.empty());
+ unsigned startDim =
+ collapsedDims.front().cast<AffineDimExpr>().getPosition();
+ AffineExpr linearizedExpr = makeCanonicalStridedLayoutExpr(
+ sourceShape.slice(startDim, collapsedDims.size()),
+ sourceExprs.slice(startDim, collapsedDims.size()), context);
+ resultExprs.push_back(linearizedExpr);
+ }
+ return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(),
+ resultExprs, context);
+}
+
+/// Checks if the `reshapeOp` can be fused with its consumer (if `asProducer` is
+/// true) or its producer (if `asProducer` is false) given the indexing map at
+/// its use.
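+/// For instance (hypothetical shapes, for exposition only): an expanding
+/// reshape such as
+///   %r = linalg.tensor_reshape %t [...] : tensor<?x?xf32> into tensor<?x4x?xf32>
+/// can act as a fusible producer for a consumer that indexes %r with the
+/// identity map, while the collapsing direction would need mods and divs in
+/// the fused indexing maps and is rejected here.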
+static bool isTensorReshapeOpFusible(TensorReshapeOp reshapeOp,
+ AffineMap useIndexMap, bool asProducer) {
+ RankedTensorType returnType = reshapeOp.getResultType();
+ RankedTensorType operandType = reshapeOp.getSrcType();
+ // Reshape is fusible with its consumer (i.e. reshape as a producer) when its
+ // operand is of lesser rank than the result. Fusing when the operand has a
+ // higher rank would require mods and divs in the indexing maps of the fused
+ // op, which would make it non-invertible. Similarly, reshape is fused with
+ // its producer (i.e. reshape as a consumer) only if the return type has a
+ // lesser rank.
+ if ((asProducer && returnType.getRank() < operandType.getRank()) ||
+ (!asProducer && operandType.getRank() < returnType.getRank()))
+ return false;
+ return useIndexMap.isIdentity();
+}
+
+/// Based on the type of `op`, create a linalg op of the same type, i.e. if `op`
+/// is a linalg.generic operation, then create a `linalg.generic` operation with
+/// the given `args`. Expects `op` to be `linalg.generic` or
+/// `linalg.indexed_generic`.
+template <typename... Args>
+static LinalgOp createLinalgOpOfSameType(LinalgOp op, PatternRewriter &rewriter,
+ Args... args) {
+ if (isa<GenericOp>(op.getOperation()))
+ return cast<LinalgOp>(rewriter.create<GenericOp>(args...).getOperation());
+ if (isa<IndexedGenericOp>(op.getOperation()))
+ return cast<LinalgOp>(
+ rewriter.create<IndexedGenericOp>(args...).getOperation());
+ llvm_unreachable(
+ "expected only linalg.generic or linalg.indexed_generic ops");
+ return nullptr;
+}
+
+namespace {
+
+/// Implementation of fusion on tensor ops when producer is a TensorReshapeOp.
+struct FuseTensorReshapeOpAsProducer {
+ static bool isFusible(TensorReshapeOp producer, LinalgOp consumer,
+ unsigned consumerIdx) {
+ return isa<GenericOp, IndexedGenericOp>(consumer.getOperation()) &&
+ consumer.hasTensorSemantics() &&
+ isTensorReshapeOpFusible(producer,
+ consumer.getInputIndexingMap(consumerIdx),
+ /*asProducer=*/true);
+ }
+
+ static LinalgOp fuse(TensorReshapeOp producer, LinalgOp consumer,
+ unsigned consumerIdx, PatternRewriter &rewriter,
+ OperationFolder *folder = nullptr) {
+ if (producer.src().getDefiningOp<ConstantOp>())
+ return nullptr;
+
+ if (!isFusible(producer, consumer, consumerIdx))
+ return nullptr;
+
+ // Compute the fused operands list.
+ Operation *consumerOp = consumer.getOperation();
+ SmallVector<Value, 2> fusedOperands(consumerOp->getOperands());
+ fusedOperands[consumerIdx] = producer.src();
+
+ // Compute indexing_maps for the fused operation. The indexing_maps for the
+ // operands of the consumer that aren't fused are the same.
+ SmallVector<AffineMap, 4> fusedIndexMaps =
+ llvm::to_vector<4>(llvm::map_range(
+ consumer.indexing_maps(), [](Attribute attr) -> AffineMap {
+ return attr.cast<AffineMapAttr>().getValue();
+ }));
+
+ // Compute the indexing map to use for the operand of the producer.
+ AffineMap modifiedMap = linearizeCollapsedDims(
+ fusedIndexMaps[consumerIdx], producer.getResultType().getShape(),
+ producer.getReassociationMaps());
+ for (AffineExpr expr : modifiedMap.getResults()) {
+ if (!expr.isPureAffine())
+ return nullptr;
+ }
+ fusedIndexMaps[consumerIdx] = modifiedMap;
+
+ // Further check that the resulting index maps can be fused and
+ // inverted. Without this the resultant op is not legal.
+ if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
+ return nullptr;
+
+ SmallVector<Attribute, 4> indexMapAttrs = llvm::to_vector<4>(
+ llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
+ return AffineMapAttr::get(map);
+ }));
+ LinalgOp fusedOp = createLinalgOpOfSameType(
+ consumer, rewriter, rewriter.getUnknownLoc(),
+ consumerOp->getResultTypes(),
+ /*inputs=*/fusedOperands,
+ /*outputBuffers=*/ValueRange{},
+ /*initTensors=*/ValueRange{}, // no init tensors for now.
+ rewriter.getArrayAttr(indexMapAttrs), consumer.iterator_types(),
+ /*doc=*/nullptr,
+ /*library_call=*/nullptr,
+ /*symbol_source=*/nullptr);
+ auto &fusedRegion = fusedOp.getOperation()->getRegion(0);
+ rewriter.cloneRegionBefore(consumerOp->getRegion(0), fusedRegion,
+ fusedRegion.begin());
+ return fusedOp;
+ }
+};
+
+/// Implementation of fusion on tensor ops when consumer is a TensorReshapeOp.
+struct FuseTensorReshapeOpAsConsumer {
+ static bool isCollapsingAndFusible(LinalgOp producer,
+ TensorReshapeOp consumer,
+ unsigned consumerIdx) {
+ return isa<GenericOp, IndexedGenericOp>(producer.getOperation()) &&
+ producer.hasTensorSemantics() &&
+ isTensorReshapeOpFusible(consumer, producer.getOutputIndexingMap(0),
+ /*asProducer=*/false);
+ }
+
+ static LinalgOp fuseCollapsingCase(LinalgOp producer,
+ TensorReshapeOp consumer,
+ unsigned consumerIdx,
+ PatternRewriter &rewriter) {
+ // The indexing_maps for the operands of the fused operation are the same as
+ // those for the operands of the producer.
+ SmallVector<AffineMap, 4> fusedIndexMaps =
+ llvm::to_vector<4>(llvm::map_range(
+ producer.indexing_maps(), [](Attribute attr) -> AffineMap {
+ return attr.cast<AffineMapAttr>().getValue();
+ }));
+ // Compute the indexing map to use for the result of the producer.
+ AffineMap modifiedMap = linearizeCollapsedDims(
+ producer.getOutputIndexingMap(0), consumer.getSrcType().getShape(),
+ consumer.getReassociationMaps());
+ for (AffineExpr expr : modifiedMap.getResults()) {
+ if (!expr.isPureAffine())
+ return nullptr;
+ }
+ fusedIndexMaps.back() = modifiedMap;
+
+ // Further check that the resulting index maps can be fused and
+ // inverted. Without this the resultant op is not legal.
+ if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
+ return nullptr;
+
+ SmallVector<Attribute, 4> indexMapAttrs = llvm::to_vector<4>(
+ llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
+ return AffineMapAttr::get(map);
+ }));
+
+ Operation *producerOp = producer.getOperation();
+ LinalgOp fusedOp = createLinalgOpOfSameType(
+ producer, rewriter, rewriter.getUnknownLoc(), consumer.getResultType(),
+ /*inputs=*/producerOp->getOperands(),
+ /*outputBuffers=*/ValueRange{},
+ /*initTensors=*/ValueRange{}, // no init tensors for now.
+ rewriter.getArrayAttr(indexMapAttrs), producer.iterator_types(),
+ /*doc=*/nullptr,
+ /*library_call=*/nullptr,
+ /*symbol_source=*/nullptr);
+ auto &fusedRegion = fusedOp.getOperation()->getRegion(0);
+ rewriter.cloneRegionBefore(producerOp->getRegion(0), fusedRegion,
+ fusedRegion.begin());
+ return fusedOp;
+ }
+
+ static bool isExpandingAndFusible(LinalgOp producer, TensorReshapeOp consumer,
+ unsigned consumerIdx) {
+ // Is fusible only if:
+ // 1) The producer is a generic op.
+ // 2) The producer has tensor semantics.
+ // 3) The tensor reshape op is an expanding case.
+ // 4) All the shapes are the same for the generic op.
+ // 5) All the indexing maps in producer are identity.
+ // 6) All the loops in producer are parallel loops.
+ // 7) The producer has a single user.
+ auto types = producer.getInputOutputShapedTypes();
+ assert(!types.empty());
+ return isa<GenericOp>(producer.getOperation()) &&
+ producer.hasTensorSemantics() &&
+ consumer.getSrcType().getRank() <
+ consumer.getResultType().getRank() &&
+ std::equal(types.begin() + 1, types.end(), types.begin()) &&
+ llvm::all_of(producer.getIndexingMaps(),
+ [](AffineMap map) { return map.isIdentity(); }) &&
+ llvm::all_of(producer.iterator_types(),
+ [](Attribute attr) {
+ return attr.cast<StringAttr>().getValue() ==
+ getParallelIteratorTypeName();
+ }) &&
+ producer.getOperation()->hasOneUse();
+ }
+
+ static LinalgOp fuseExpandingCase(LinalgOp producer, TensorReshapeOp consumer,
+ unsigned consumerIdx,
+ PatternRewriter &rewriter) {
+ Location loc = producer.getLoc();
+ auto dstShape = consumer.getResultType().cast<ShapedType>().getShape();
+ SmallVector<Value, 4> args;
+ for (auto arg : producer.getOperation()->getOperands()) {
+ auto type = RankedTensorType::get(
+ dstShape, arg.getType().cast<ShapedType>().getElementType());
+ args.push_back(rewriter.createOrFold<linalg::TensorReshapeOp>(
+ loc, type, arg, consumer.reassociation()));
+ }
+
+ SmallVector<Type, 4> resultTypes;
+ for (auto t : producer.getOutputTensorTypes()) {
+ Type type = RankedTensorType::get(dstShape,
+ t.cast<ShapedType>().getElementType());
+ resultTypes.push_back(type);
+ }
+
+ int rank = dstShape.size();
+ auto genericOp = rewriter.create<linalg::GenericOp>(
+ loc, resultTypes, /*inputs=*/args,
+ /*outputBuffers=*/ValueRange{},
+ /*initTensors=*/ValueRange{},
+ SmallVector<AffineMap, 3>(args.size() + resultTypes.size(),
+ rewriter.getMultiDimIdentityMap(rank)),
+ SmallVector<StringRef, 3>(rank, getParallelIteratorTypeName()));
+ Region ®ion = genericOp.getRegion();
+ rewriter.cloneRegionBefore(producer.getOperation()->getRegion(0), region,
+ region.begin());
+ return cast<LinalgOp>(genericOp.getOperation());
+ }
+
+ static LinalgOp fuse(LinalgOp producer, TensorReshapeOp consumer,
+ unsigned consumerIdx, PatternRewriter &rewriter,
+ OperationFolder *folder = nullptr) {
+ if (isCollapsingAndFusible(producer, consumer, consumerIdx))
+ return fuseCollapsingCase(producer, consumer, consumerIdx, rewriter);
+ if (isExpandingAndFusible(producer, consumer, consumerIdx))
+ return fuseExpandingCase(producer, consumer, consumerIdx, rewriter);
+ return nullptr;
+ }
+};
+
+/// Implementation of fusion on tensor ops when producer is a splat constant.
+struct FuseConstantOpAsProducer {
+ static bool isFusible(ConstantOp producer, LinalgOp consumer,
+ unsigned consumerIdx) {
+ return isa<GenericOp, IndexedGenericOp>(consumer.getOperation()) &&
+ consumer.hasTensorSemantics() &&
+ producer.getResult().getType().isa<RankedTensorType>() &&
+ producer.value().cast<DenseElementsAttr>().isSplat();
+ }
+
+ static LinalgOp fuse(ConstantOp producer, LinalgOp consumer,
+ unsigned consumerIdx, PatternRewriter &rewriter,
+ OperationFolder *folder = nullptr) {
+ if (!isFusible(producer, consumer, consumerIdx))
+ return nullptr;
+
+ // The indexing_maps for the operands of the fused operation are the same as
+ // those for the operands of the consumer, without the indexing map at
+ // consumerIdx.
+ SmallVector<AffineMap, 4> fusedIndexMaps =
+ llvm::to_vector<4>(llvm::map_range(
+ consumer.indexing_maps(), [](Attribute attr) -> AffineMap {
+ return attr.cast<AffineMapAttr>().getValue();
+ }));
+ fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), consumerIdx));
+
+ // The operand list is the same as that of the consumer, with the operand at
+ // the constant's index (consumerIdx) dropped.
+ Operation *consumerOp = consumer.getOperation();
+ SmallVector<Value, 4> fusedOperands(consumerOp->getOperands());
+ fusedOperands.erase(std::next(fusedOperands.begin(), consumerIdx));
+
+ // Create a constant scalar value from the splat constant.
+ Value scalarConstant = rewriter.create<ConstantOp>(
+ producer.getLoc(),
+ producer.value().cast<DenseElementsAttr>().getSplatValue());
+
+ LinalgOp fusedOp = createLinalgOpOfSameType(
+ consumer, rewriter, rewriter.getUnknownLoc(),
+ consumerOp->getResultTypes(),
+ /*inputs=*/fusedOperands,
+ /*outputBuffers=*/ValueRange{},
+ /*initTensors=*/ValueRange{}, // no init tensors for now.
+ rewriter.getAffineMapArrayAttr(fusedIndexMaps),
+ consumer.iterator_types(),
+ /*doc=*/nullptr,
+ /*library_call=*/nullptr,
+ /*symbol_source=*/nullptr);
+
+ // Map the block argument corresponding to the replaced argument with the
+ // scalar constant.
+ Region &consumerRegion = consumerOp->getRegion(0);
+ Block &entryBlock = *consumerRegion.begin();
+ unsigned argIndex = entryBlock.getNumArguments() -
+ consumerOp->getNumOperands() + consumerIdx;
+ BlockAndValueMapping mapping;
+ mapping.map(entryBlock.getArgument(argIndex), scalarConstant);
+ Region &fusedRegion = fusedOp.getOperation()->getRegion(0);
+ rewriter.cloneRegionBefore(consumerRegion, fusedRegion, fusedRegion.begin(),
+ mapping);
+ return fusedOp;
+ }
+};
+} // namespace
+
+Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
+ Operation *consumer,
+ unsigned consumerIdx,
+ OperationFolder *folder) {
+ if (consumerIdx >= consumer->getNumOperands())
+ return nullptr;
+ Operation *producer = consumer->getOperand(consumerIdx).getDefiningOp();
+ if (!producer || producer->getNumResults() != 1)
+ return nullptr;
+
+ // Fuse when consumer is GenericOp or IndexedGenericOp.
+ if (isa<GenericOp, IndexedGenericOp>(consumer)) {
+ if (isa<GenericOp, IndexedGenericOp>(producer))
+ return FuseGenericOpsOnTensors::fuse(cast<LinalgOp>(producer),
+ cast<LinalgOp>(consumer),
+ consumerIdx, rewriter, folder);
+ if (auto reshapeOpProducer = dyn_cast<TensorReshapeOp>(producer))
+ return FuseTensorReshapeOpAsProducer::fuse(reshapeOpProducer,
+ cast<LinalgOp>(consumer),
+ consumerIdx, rewriter, folder);
+ if (auto constantOpProducer = dyn_cast<ConstantOp>(producer))
+ return FuseConstantOpAsProducer::fuse(constantOpProducer,
+ cast<LinalgOp>(consumer),
+ consumerIdx, rewriter, folder);
+ return nullptr;
+ }
+
+ if (isa<GenericOp, IndexedGenericOp>(producer)) {
+ // Fuse when consumer is a TensorReshapeOp.
+ if (TensorReshapeOp reshapeOp = dyn_cast<TensorReshapeOp>(consumer)) {
+ return FuseTensorReshapeOpAsConsumer::fuse(
+ cast<LinalgOp>(producer), reshapeOp, consumerIdx, rewriter, folder);
+ }
+ }
+
+ return nullptr;
+}
+
+namespace {
+/// Patterns to fuse a generic op with the producers of its operands.
+template <typename LinalgOpTy>
+struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
+ using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(LinalgOpTy op,
+ PatternRewriter &rewriter) const override {
+ // Find the first operand that is defined by another generic op on tensors.
+ for (auto operandNum :
+ llvm::seq<unsigned>(0, op.getOperation()->getNumOperands())) {
+ Operation *producer =
+ op.getOperation()->getOperand(operandNum).getDefiningOp();
+ if (Operation *fusedOp = fuseTensorOps(rewriter, op, operandNum)) {
+ rewriter.replaceOp(op, fusedOp->getResults());
+ if (producer && llvm::all_of(producer->getResults(),
+ [](Value val) { return val.use_empty(); }))
+ rewriter.eraseOp(producer);
+ return success();
+ }
+ }
+ return failure();
+ }
+};
+
+/// Pass that fuses generic ops on tensors. Used only for testing.
+struct FusionOfTensorOpsPass
+ : public LinalgFusionOfTensorOpsBase<FusionOfTensorOpsPass> {
+ void runOnOperation() override {
+ OwningRewritePatternList patterns;
+ Operation *op = getOperation();
+ populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns);
+ applyPatternsAndFoldGreedily(op->getRegions(), patterns);
+ };
+};
+} // namespace
+
+void mlir::populateLinalgTensorOpsFusionPatterns(
+ MLIRContext *context, OwningRewritePatternList &patterns) {
+ patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
+ FuseTensorOps<TensorReshapeOp>>(context);
+}
+
+std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
+ return std::make_unique<FusionOfTensorOpsPass>();
+}