[Mlir-commits] [mlir] 485190d - [mlir][Linalg] Deprecate `tileAndFuseLinalgOps` method and associated patterns.
Mahesh Ravishankar
llvmlistbot at llvm.org
Wed Jul 20 22:05:21 PDT 2022
Author: Mahesh Ravishankar
Date: 2022-07-21T05:05:06Z
New Revision: 485190df95f98c51c3f4a4ab4db96127cdc9ce78
URL: https://github.com/llvm/llvm-project/commit/485190df95f98c51c3f4a4ab4db96127cdc9ce78
DIFF: https://github.com/llvm/llvm-project/commit/485190df95f98c51c3f4a4ab4db96127cdc9ce78.diff
LOG: [mlir][Linalg] Deprecate `tileAndFuseLinalgOps` method and associated patterns.
The `tileAndFuseLinalgOps` method is a legacy approach for tiling +
fusion of Linalg operations. Since it was also intended to work on
operations with buffer operands, it had fairly complex logic to make
sure tile and fuse was correct even with side-effecting Linalg ops.
Despite that complexity, it still wasn't robust enough. This patch
deprecates this method, and with it the tiling + fusion path for ops
with buffer semantics. Note that the core transformation that fuses a
producer with a tiled consumer still exists. The deprecation only
removes the methods that tried to automatically tile and fuse correctly
in the presence of side effects.
`tileAndFuseLinalgOps` also works on operations with tensor semantics.
For those, the same functionality is available in at least two other
ways:
1) The `tileConsumerAndFuseProducers` method. This does a similar
transformation, but uses slightly different logic to automatically
determine legal tile + fuse code. Note that this method is also
expected to be deprecated soon.
2) The preferred way uses the `TilingInterface` for tile + fuse, and
relies on the caller to set the tiling options so that the generated
code is correct.
To show that (2) is equivalent to the functionality provided by
`tileAndFuseLinalgOps`, the relevant tests have been moved to use the
interface, with the test driver setting the tile sizes needed to
generate the expected code.
Differential Revision: https://reviews.llvm.org/D129901
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
mlir/test/Dialect/Linalg/fusion-pattern.mlir
mlir/test/Dialect/Linalg/fusion-sequence.mlir
mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 68a41fd1b14d..c81a7ee2ac32 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -169,71 +169,6 @@ void peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
ArrayRef<int64_t> peeledLoops,
LinalgTilingLoopType loopType);
-/// Fuse a sequence of linalg operations (`ops`) using tile-and-fuse. This
-/// proceeds as follows:
-/// - Find outer parallel loops in these ops that can be fused.
-/// - Tile fusable outer parallel loops of the last operation in the sequence.
-/// - Fuse the remaining operations with the tiled operation
-///
-/// For example, consider the sequence of matmul below
-///
-/// linalg.matmul ins(%arg0, %arg1 : memref<256x32xf32>, memref<32x32xf32>)
-/// outs(%arg2 : memref<256x32xf32>)
-/// linalg.matmul ins(%arg2, %arg3 : memref<256x32xf32>, memref<32x32xf32>)
-/// outs(%arg4 : memref<256x32xf32>)
-///
-/// It is legal to fuse the RAW dependence (through %arg2) by only fusing the
-/// matmuls row-wise. For example, the fused computation for the above is shown
-/// below. The outer `scf.parallel` loop is the "fused" loop obtained by tiling
-/// along the rows of the matrix. The entire rows of the first matmul operation
-/// need to be computed before they can be used for the second matmul. The
-/// second matmul is further tiled (similar to normal tiling).
-///
-/// #map0 = affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>
-/// #map1 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
-/// scf.parallel (%arg5) = (%c0) to (%c256) step (%c16) {
-/// %0 = subview %arg2[%arg5, 0] [16, 32] [1, 1]
-/// : memref<256x32xf32> to memref<16x32xf32, #map0>
-/// %1 = subview %arg4[%arg5, 0] [16, 32] [1, 1]
-/// : memref<256x32xf32> to memref<16x32xf32, #map0>
-/// %2 = subview %arg0[%arg5, 0] [16, 32] [1, 1]
-/// : memref<256x32xf32> to memref<16x32xf32, #map0>
-/// %3 = subview %arg1[0, 0] [32, 32] [1, 1]
-/// : memref<32x32xf32> to memref<32x32xf32, #map1>
-/// %4 = subview %arg3[0, 0] [32, 32] [1, 1]
-/// : memref<32x32xf32> to memref<32x32xf32, #map1>
-/// linalg.matmul
-/// ins(%2, %3 : memref<16x32xf32, #map0>, memref<32x32xf32, #map1>)
-/// outs(%0 : memref<16x32xf32, #map0>)
-/// linalg.matmul
-/// ins(%0, %4 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>)
-/// outs(%1 : memref<16x8xf32, #map0>)
-/// }
-///
-/// `tilingOptions` are used to tile the corresponding operation in `ops` (the
-/// size of the former should be same as size of the latter. Based on how
-/// tile+fuse is implemented, the fused loops are generated based on the last
-/// operation in the sequence. For example, the tile sizes for the fused loops
-/// is obtained from `tilingOptions.back()`. The following tiling options are
-/// handled differently in tile+fuse (compared to tile only)
-/// - Interchange of the tiling loops is not supported right now.
-/// - Only the fused loops are distributed.
-struct TiledAndFusedLinalgOps {
- /// Operation obtained by tiling the last operation in sequence of `ops`
- /// passed to `tileAndFuseLinalgOps`.
- LinalgOp op;
- /// The dimension of the loops that are fused.
- std::set<unsigned> fusedLoopDims;
- /// The generated fused operations (created within the fused loops).
- SmallVector<LinalgOp, 1> fusedProducers;
- /// The fused loop generated.
- SmallVector<Operation *, 4> fusedLoops;
-};
-FailureOr<TiledAndFusedLinalgOps>
-tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
- const LinalgDependenceGraph &dependenceGraph,
- const LinalgTilingOptions &tilingOptions);
-
/// Interchange the `iterator_types` and `iterator_maps` dimensions and adapts
/// the index accesses of `op`. This is an in-place transformation controlled by
/// `interchangeVector`. An empty vector is interpreted as the identity
@@ -847,62 +782,6 @@ struct DownscaleDepthwiseConv2DNhwcHwcOp final
LinalgTransformationFilter filter;
};
-struct LinalgFusionOptions {
- /// List of operands indices to use for fusion.
- llvm::SmallSet<unsigned, 1> indicesToFuse = {};
- LinalgFusionOptions &setIndicesToFuse(ArrayRef<int64_t> operands) {
- indicesToFuse.insert(operands.begin(), operands.end());
- return *this;
- }
-};
-
-struct LinalgBaseTileAndFusePattern : public RewritePattern {
- LinalgBaseTileAndFusePattern(
- StringRef opName, MLIRContext *context,
- const LinalgDependenceGraph &dependenceGraph,
- LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
- LinalgTransformationFilter f = LinalgTransformationFilter(),
- LinalgTransformationFilter fusedOpMarker = LinalgTransformationFilter(),
- LinalgTransformationFilter originalOpMarker =
- LinalgTransformationFilter(),
- PatternBenefit benefit = 1);
- LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override;
-
-private:
- /// Dependence graph needed for fusion.
- const LinalgDependenceGraph &dependenceGraph;
- /// Options to control tiling.
- LinalgTilingOptions tilingOptions;
- /// Options to control fusion.
- LinalgFusionOptions fusionOptions;
- /// Marker to control application of the pattern.
- LinalgTransformationFilter filter;
- /// Marker set on the fused op after tile and fuse.
- LinalgTransformationFilter fusedOpMarker;
- /// The dependenceGraph is not modifiable, i.e. if the Linalg operations used
- /// to build the dependence graph changes then the dependenceGraph needs to be
- /// recomputed right now. To not invalidate the dependenceGraph as
- /// transformation happens, the original producer can be tagged with a filter
- /// that can be later used to delete the original operations.
- LinalgTransformationFilter originalOpMarker;
-};
-
-template <typename OpTy>
-struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern {
- LinalgTileAndFusePattern(
- MLIRContext *context, const LinalgDependenceGraph &dependenceGraph,
- LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
- LinalgTransformationFilter f = LinalgTransformationFilter(),
- LinalgTransformationFilter fusedOpMarker = LinalgTransformationFilter(),
- LinalgTransformationFilter originalOpMarker =
- LinalgTransformationFilter(),
- PatternBenefit benefit = 1)
- : LinalgBaseTileAndFusePattern(
- OpTy::getOperationName(), context, dependenceGraph, tilingOptions,
- fusionOptions, f, fusedOpMarker, originalOpMarker, benefit) {}
-};
-
///
/// Linalg tile and fuse tensor ops pattern.
///
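The deleted doc comment for `tileAndFuseLinalgOps` argues that two chained matmuls communicating through `%arg2` can only be fused row-wise, because each row tile of the second matmul needs complete rows of the first matmul's result. The plain C++ sketch below is not MLIR code; `Matrix`, `matmulRows`, `unfusedChain` and `rowFusedChain` are hypothetical names chosen for illustration. It computes the same chain once unfused and once with the outer row loop tiled and fused, which makes the legality argument concrete: each producer row tile is complete before the consumer reads it.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical row-major dense matrix used only for this sketch.
struct Matrix {
  std::size_t rows, cols;
  std::vector<float> data;
  Matrix(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c, 0.0f) {}
  float &at(std::size_t i, std::size_t j) { return data[i * cols + j]; }
  float at(std::size_t i, std::size_t j) const { return data[i * cols + j]; }
};

// out[rowBegin:rowEnd, :] += lhs[rowBegin:rowEnd, :] * rhs
static void matmulRows(const Matrix &lhs, const Matrix &rhs, Matrix &out,
                       std::size_t rowBegin, std::size_t rowEnd) {
  assert(lhs.cols == rhs.rows && out.rows == lhs.rows && out.cols == rhs.cols);
  for (std::size_t i = rowBegin; i < rowEnd; ++i)
    for (std::size_t j = 0; j < rhs.cols; ++j)
      for (std::size_t k = 0; k < lhs.cols; ++k)
        out.at(i, j) += lhs.at(i, k) * rhs.at(k, j);
}

// Unfused: compute all of C = A * B, then all of E = C * D.
Matrix unfusedChain(const Matrix &a, const Matrix &b, const Matrix &d) {
  Matrix c(a.rows, b.cols), e(a.rows, d.cols);
  matmulRows(a, b, c, 0, a.rows);
  matmulRows(c, d, e, 0, c.rows);
  return e;
}

// Fused along rows: for every tile of `tileRows` rows, produce those full
// rows of C and immediately consume them to produce the same rows of E.
Matrix rowFusedChain(const Matrix &a, const Matrix &b, const Matrix &d,
                     std::size_t tileRows) {
  assert(tileRows > 0);
  Matrix c(a.rows, b.cols), e(a.rows, d.cols);
  for (std::size_t r = 0; r < a.rows; r += tileRows) {
    std::size_t end = std::min(r + tileRows, a.rows);
    matmulRows(a, b, c, r, end); // producer tile: complete rows of C
    matmulRows(c, d, e, r, end); // consumer tile reads only those rows
  }
  return e;
}
```

Because every element is accumulated in the same order in both versions, the fused and unfused results should be identical. Tiling along the columns of C instead would not be legal here, since a row tile of E reads every column of the corresponding rows of C.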
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index a7076911b59d..91089a382132 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -460,436 +460,3 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
consumerOpOperand.set(def);
return FusionInfo{cast<LinalgOp>(producerOpResult.getOwner()), fusedProducer};
}
-
-/// Prune all dimensions that are of reduction iterator type from `map`.
-static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
- AffineMap map) {
- llvm::SmallBitVector projectedDims(iteratorTypes.size());
- for (const auto &attr : llvm::enumerate(iteratorTypes)) {
- if (!isParallelIterator(attr.value()))
- projectedDims.set(attr.index());
- }
- return getProjectedMap(map, projectedDims);
-}
-
-/// Returns the mapping from iterations in the consumer that write to the same
-/// location as the iterations in the producer. To do so use
-/// - indexing map of the fused view in the consumer : consumerIndexMap
-/// - indexing map of the fused view in the producer : producerIndexMap
-/// consumerLoopToProducerLoop =
-/// inverse(producerIndexMap).compose(consumerIndexMap)
-static FailureOr<AffineMap> getConsumerLoopToProducerLoopMap(
- LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
- auto producer = dyn_cast<LinalgOp>(dependence.getDependentOp());
- if (!producer)
- return failure();
-
- Optional<AffineMap> producerIndexingMap =
- dependence.getDependentOpViewIndexingMap();
- Optional<AffineMap> consumerIndexingMap =
- dependence.getIndexingOpViewIndexingMap();
- if (!producerIndexingMap || !consumerIndexingMap)
- return failure();
-
- AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
- producer.iterator_types().getValue(), *producerIndexingMap);
- if (!prunedProducerIndexingMap.isPermutation())
- return failure();
-
- if (consumerIndexingMap->getNumResults() !=
- prunedProducerIndexingMap.getNumResults())
- return failure();
-
- LLVM_DEBUG({
- llvm::dbgs() << "\t producerMap : ";
- producerIndexingMap->print(llvm::dbgs());
- llvm::dbgs() << " pruned : ";
- prunedProducerIndexingMap.print(llvm::dbgs());
- llvm::dbgs() << "\n";
- llvm::dbgs() << "\t consumerMap : ";
- consumerIndexingMap->print(llvm::dbgs());
- llvm::dbgs() << "\n";
- });
-
- AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap);
- if (!invProducerIndexMap)
- return failure();
-
- return invProducerIndexMap.compose(*consumerIndexingMap);
-}
-
-/// Given a projected permutation `map`, returns true if the map changes the
-/// order in which the fused loop dimension appear.
-static bool doesTransposeAccess(AffineMap map,
- const std::set<unsigned> &fusableLoops) {
- Optional<unsigned> lastFusableLoop;
- for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) {
- return expr.cast<AffineDimExpr>().getPosition();
- })) {
- if (!fusableLoops.count(pos))
- continue;
- if (!lastFusableLoop) {
- lastFusableLoop = pos;
- continue;
- }
- if (pos <= *lastFusableLoop)
- return true;
- lastFusableLoop = pos;
- }
- return false;
-}
-
-/// Returns the positions of the loop in `op` that can be tiled based on the
-/// operations that are to be fused with it. For example, in a
-///
-/// linalg.matmul ins(%a, %b : ...) outs(%c : ...)
-///
-/// if the producer of %a needs to be fused with this op, only the `i` loop of
-/// the matmul can be tiled while fusing. If producer of %a, and %b are to be
-/// fused, then no loops can be tiled while fusing. The conditions used are:
-/// 1. Only parallel loops can be used for tile + fuse. Find the number of
-/// common outer parallel loops between the op and its producers being fused.
-/// 2. Of the parallel loops only some can be fused. Only those loops can be
-/// fused such where the fusable loops iteration space only touches one tile
-/// of the fused operation. This is because the producer (which is writing
-/// the fused subview) has update semantics.
-///
-/// Since an inverse computation is needed, we need to consider the projection
-/// of the producerIndexMap w.r.t the parallel loops. The actual fusable loops
-/// are the dimensions of the consumerLoopToProducerLoop map that correspond to
-/// parallel loops and appear in the result of the map
-///
-/// Example 1:
-/// linalg.fill(%cst, %c)
-/// linalg.matmul ins(%a, %b) outs(%c)
-/// Number of parallel loops : 2
-/// producerIndexMap = affine_map<(i, j) ->(i , j)>
-/// consumerIndexMap = affine_map<(i, j, k) -> (i, j)>
-/// consumerLoopToProducerLoop = affine_map<(i, j, k) -> (i, j)>
-/// Fused dimensions : i, j
-///
-/// Example 2:
-/// linalg.matmul ins(%a, %b) outs(%c)
-/// linalg.generic {indexing_maps = [affine_map<(i, j) -> (j, i)>, ...
-/// iterator_types = ["parallel", "parallel"]}
-/// ins(%c) ...
-///
-/// Number of parallel loops = 2:
-/// producerIndexMap (projected to parallel loops) =
-/// affine_map<(i, j) -> (i, j)>
-/// consumerLoopToProducerLoop2 = affine_map<(i, j) -> (j, i)>
-/// Fused dimensions : i, j
-///
-/// Example 3:
-/// memref.copy(%s, %b)
-/// linalg.matmul ins(%a, %b) outs(%c)
-///
-/// Number of parallel loops = 2
-/// produceIndexMap : affine_map<(i, j) -> (i, j)>
-/// consumerLoopToProduceLoops = affine_map<(i, j, k) -> (k, j)>
-/// submap with only parallel loops = affine_map<(i, j) -> (j)>
-/// Fused dimensions : j
-static std::set<unsigned>
-collectFusableLoops(ArrayRef<LinalgOp> ops,
- const FusableOpDependencesTy &fusableDependences) {
- assert(!ops.empty());
- auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
- return linalgOp.iterator_types()
- .getValue()
- .take_while([](Attribute attr) -> bool {
- return attr.cast<StringAttr>().getValue() ==
- getParallelIteratorTypeName();
- })
- .size();
- };
-
- size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back());
- for (auto op : ops.drop_back()) {
- numOuterParallelLoops =
- std::min(numOuterParallelLoops, getNumOuterParallelLoops(op));
- }
-
- std::set<unsigned> fusableLoops;
- auto range = llvm::seq<unsigned>(0, numOuterParallelLoops);
- fusableLoops.insert(range.begin(), range.end());
-
- for (auto op : reverse(ops)) {
- for (auto dependence : fusableDependences.lookup(op)) {
- LLVM_DEBUG({
- llvm::dbgs() << "\t fusable :";
- for (unsigned i : fusableLoops)
- llvm::dbgs() << " " << i;
- llvm::dbgs() << "\n";
- });
-
- Optional<AffineMap> consumerLoopToProducerLoop =
- getConsumerLoopToProducerLoopMap(dependence);
- if (!consumerLoopToProducerLoop) {
- op.emitRemark("failed to get map from consumer loop to producer loop");
- return {};
- }
- // todo: This condition is only an implementation limitation. When fusing
- // the operation, if the accesses in the producer/consumer are transposes
- // of each other, the loop bounds for the tiled producer can be
- // manipulated accordingly. This requires some additional bookkeeping in
- // the implementation of tile+fuse that is deferred to later.
- if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) {
- op.emitRemark("unhandled fusion when fusion requires permutation");
- return {};
- }
-
- std::set<unsigned> candidates;
- for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) {
- unsigned position = expr.cast<AffineDimExpr>().getPosition();
- if (fusableLoops.count(position))
- candidates.insert(position);
- }
- LLVM_DEBUG({
- llvm::dbgs() << "\t candidates :";
- for (unsigned i : candidates)
- llvm::dbgs() << " " << i;
- llvm::dbgs() << "\n";
- });
- if (candidates.empty())
- return {};
- std::swap(candidates, fusableLoops);
- }
- }
-
- return fusableLoops;
-}
-
-/// Find all dependences that are fusable.
-FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
- ArrayRef<LinalgOp> ops, const LinalgDependenceGraph &dependenceGraph) {
- FusableOpDependencesTy fusableDependences;
- DenseMap<Operation *, SmallVector<AffineMap, 1>> fusedProducerIndexingMap;
- for (LinalgOp op : reverse(ops)) {
- for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
- Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
- fusableDependence = findFusableProducer(*opOperand, dependenceGraph);
- if (!fusableDependence)
- continue;
- LinalgOp producerOp =
- dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
- if (!producerOp)
- continue;
- // Do not fuse dependences that are to operations not in the same basic
- // block. This avoid moving fused operations across loops that might
- // themselves carry dependency making the fusion illegal.
- if (producerOp->getBlock() != op->getBlock())
- continue;
-
- // Make sure that the indexing map of the view used for fusion in the
- // producer is a projected permutation.
- Optional<AffineMap> producerMap =
- fusableDependence->getDependentOpViewIndexingMap();
- Optional<AffineMap> consumerMap =
- fusableDependence->getIndexingOpViewIndexingMap();
- assert(
- consumerMap &&
- "unable to find indexing map of operand/result of indexing OpView");
- fusedProducerIndexingMap[producerOp.getOperation()].push_back(
- *consumerMap);
- if (!producerMap || !producerMap->isProjectedPermutation() ||
- !consumerMap->isProjectedPermutation())
- continue;
-
- fusableDependences[producerOp.getOperation()].push_back(
- *fusableDependence);
- }
- }
- // TODO: Currently fusion would not be legal if the fusable dependence is to
-  // the same producer but different indexing map in the consumer. Fix this, but
- // in the meanwhile disallow such a fusion.
- for (auto useIndexingMapsList : fusedProducerIndexingMap) {
- AffineMap map1 = useIndexingMapsList.second.front();
- for (AffineMap map2 :
- ArrayRef<AffineMap>(useIndexingMapsList.second).drop_front()) {
- if (map1 != map2) {
- fusableDependences.erase(useIndexingMapsList.first);
- break;
- }
- }
- }
- return fusableDependences;
-}
-
-/// Tile the fused loops in the root operation, by setting the tile sizes for
-/// all other loops to zero (those will be tiled later).
-static FailureOr<TiledLinalgOp>
-tileRootOperation(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizeVector,
- const LinalgTilingOptions &options,
- const std::set<unsigned> &fusedLoops) {
- SmallVector<Value, 4> tileSizes(tileSizeVector.begin(), tileSizeVector.end());
- auto zero = b.create<arith::ConstantIndexOp>(op.getLoc(), 0);
- for (unsigned i = 0, e = tileSizes.size(); i != e; ++i)
- if (!fusedLoops.count(i))
- tileSizes[i] = zero;
- LinalgTilingOptions tileFusedLoopsOptions = options;
- tileFusedLoopsOptions.setTileSizes(tileSizes);
- // TODO: Propagate RewriterBase everywhere.
- IRRewriter rewriter(b);
- return tileLinalgOp(rewriter, op, tileFusedLoopsOptions);
-}
-
-/// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected
-/// to be a tiled operation such that it is valid to fuse all operations in
-/// `fusionCandidates`, i.e. move the operation within the inter-tile loops of
-/// `tiledOp`.
-static SmallVector<LinalgOp, 1>
-fuseOperations(OpBuilder &b, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp,
- ArrayRef<LinalgOp> fusionCandidates,
- const FusableOpDependencesTy &fusableDependences,
- const std::set<unsigned> &fusedLoops) {
- LinalgOp tiledOp = tiledLinalgOp.op;
- OpBuilder::InsertionGuard guard(b);
- b.setInsertionPoint(tiledOp);
-
- DenseMap<unsigned, Range> fusedLoopsAndRanges;
- for (unsigned loop : fusedLoops) {
- ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop, true);
- fusedLoopsAndRanges[loop] = getRangeFromOperandShape(
- b, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension);
- }
-
- SmallVector<LinalgOp, 1> fusedOps(fusionCandidates.size());
- DenseMap<Operation *, LinalgOp> origOpToFusedOp;
- origOpToFusedOp[rootOp.getOperation()] = tiledOp;
- for (const auto &candidate : enumerate(llvm::reverse(fusionCandidates))) {
- LinalgOp origOp = candidate.value();
- LinalgOp fusedOp = fuse(b, origOp, fusedLoopsAndRanges);
- origOpToFusedOp[origOp.getOperation()] = fusedOp;
- fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
-
- // Prepare the builder for the next insertion point.
- auto guard = llvm::make_scope_exit([&]() { b.setInsertionPoint(fusedOp); });
- if (!origOp.hasTensorSemantics())
- continue;
-
- // If the producer consumer operations are linalg operations on tensors, the
- // dependence is due to value produced (as a return tensor) by the producer
- // and used in the consumer. The returned value of the fused op needs to be
- // made the operand of the tiled/fused consumer operation. By construction
- // the value returned by the producer is the value used by the consumer.
- for (auto &dependence : fusableDependences.lookup(origOp.getOperation())) {
- if (dependence.dependenceType !=
- LinalgDependenceGraph::DependenceType::RAW)
- continue;
-
- unsigned resultIndex = dependence.getDependentOpViewResultNum().value();
- LinalgOp consumer = origOpToFusedOp.lookup(dependence.getIndexingOp());
- if (!consumer)
- continue;
-
- Value replacementValue = fusedOp.getOperation()->getResult(resultIndex);
- consumer.getOperation()->setOperand(
- dependence.getIndexingOpViewOperandNum().value(), replacementValue);
- }
-
- // At this point, all Linalg uses of the tensors produced by `origOp` have
- // been replaced. However, there may still be "output tensor"-like uses
- // coming from WAW dependencies.
- // All these uses are iter_args of the outermost loop (TODO: add a check).
- // Such iter_args uses serve 2 purposes:
- // 1. give a shape to the output
- // 2. encode destructive updates that may be inplaceable by bufferization.
- // To keep the second type of information while letting the unfused op die
- // unused, we need to forward the producer output operand.
- if (auto forOp = dyn_cast<scf::ForOp>(tiledLinalgOp.loops.front())) {
- for (auto &operand : forOp.getIterOpOperands()) {
- if (auto opResult = operand.get().dyn_cast<OpResult>()) {
- if (opResult.getOwner() == origOp) {
- Value output =
- origOp.getOutputOperand(opResult.getResultNumber())->get();
- assert(output.getType().isa<RankedTensorType>());
- operand.set(output);
- }
- }
- }
- }
- }
- return fusedOps;
-}
-
-static FailureOr<TiledAndFusedLinalgOps>
-tileAndFuseLinalgOpsImpl(OpBuilder &b, ArrayRef<LinalgOp> ops,
- const LinalgDependenceGraph &dependenceGraph,
- const LinalgTilingOptions &tilingOptions) {
- if (ops.size() < 2)
- return failure();
- LinalgOp rootOp = ops.back();
- if (!llvm::all_of(
- ops,
- [](LinalgOp linalgOp) { return linalgOp.hasBufferSemantics(); }) &&
- !llvm::all_of(ops, [](LinalgOp linalgOp) {
- return linalgOp.hasTensorSemantics();
- })) {
- rootOp.emitError(
- "unable to fuse operations that have tensor semantics with operations "
- "that have buffer semantics and viceversa.");
- return failure();
- }
- // TODO: Support interchange with tile + fuse. This might actually help do
- // better fusion.
- if (!tilingOptions.interchangeVector.empty()) {
- rootOp.emitRemark("unable to handle tile and fuse with interchange");
- return failure();
- }
-
- OpBuilder::InsertionGuard guard(b);
- b.setInsertionPoint(rootOp);
-
- // Find all the producers.
- LLVM_DEBUG(llvm::dbgs() << "findAllFusableDependences\n");
- FusableOpDependencesTy fusableDependences =
- findAllFusableDependences(ops, dependenceGraph);
- if (fusableDependences.empty()) {
- LLVM_DEBUG(llvm::dbgs() << "no fusable dependencies found\n");
- return failure();
- }
-
- TiledAndFusedLinalgOps ret;
- // Find the loops that can be tiled and fused.
- LLVM_DEBUG(llvm::dbgs() << "collectFusableLoops\n");
- ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences);
-
- // If there are no fusable dependences or there are no tile+fusable loops,
- // just return.
- if (ret.fusedLoopDims.empty()) {
- LLVM_DEBUG(llvm::dbgs() << "no fusable loops found\n");
- return failure();
- }
-
- // Tile the fused loops in the last operation in the list.
- SmallVector<Value, 4> tileSizeVector =
- tilingOptions.tileSizeComputationFunction(b, rootOp);
- FailureOr<TiledLinalgOp> tiledRootOp = tileRootOperation(
- b, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
- if (failed(tiledRootOp)) {
- rootOp.emitRemark("failed to tile the fused loops");
- return failure();
- }
- ret.op = tiledRootOp->op;
- ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
-
- // Fuse the other operations into the fused inter-tile loops produced above.
- ret.fusedProducers = fuseOperations(b, rootOp, *tiledRootOp, ops.drop_back(),
- fusableDependences, ret.fusedLoopDims);
-
- return ret;
-}
-
-FailureOr<TiledAndFusedLinalgOps>
-mlir::linalg::tileAndFuseLinalgOps(OpBuilder &b, ArrayRef<LinalgOp> ops,
- const LinalgDependenceGraph &dependenceGraph,
- const LinalgTilingOptions &tilingOptions) {
- switch (tilingOptions.loopType) {
- case LinalgTilingLoopType::Loops:
- case LinalgTilingLoopType::ParallelLoops:
- case LinalgTilingLoopType::TiledLoops:
- return tileAndFuseLinalgOpsImpl(b, ops, dependenceGraph, tilingOptions);
- default:;
- }
- return failure();
-}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index f90a7c08669b..582c2cee9d07 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -350,117 +350,6 @@ void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
}
}
-static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
- if (tiledOp.loops.empty())
- return tiledOp.op.getOperation()->getResults();
- return tiledOp.loops.front()->getResults();
-}
-
-static ValueRange
-getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
- if (tiledAndFusedOp.fusedLoops.empty())
- return tiledAndFusedOp.op.getOperation()->getResults();
- return tiledAndFusedOp.fusedLoops.front()->getResults();
-}
-
-mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
- StringRef opName, MLIRContext *context,
- const LinalgDependenceGraph &dependenceGraph,
- LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
- LinalgTransformationFilter f, LinalgTransformationFilter fusedOpMarker,
- LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
- : RewritePattern(opName, benefit, context, {}),
- dependenceGraph(dependenceGraph), tilingOptions(std::move(tilingOptions)),
- fusionOptions(std::move(fusionOptions)), filter(std::move(f)),
- fusedOpMarker(std::move(fusedOpMarker)),
- originalOpMarker(std::move(originalOpMarker)) {}
-
-LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
- Operation *op, PatternRewriter &rewriter) const {
- LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
- // TODO: remove hasIndexSemantics check once index ops are supported.
- if (!linalgOp || linalgOp.hasIndexSemantics())
- return failure();
- if (failed(filter.checkAndNotify(rewriter, linalgOp)))
- return failure();
-
- DenseSet<Operation *> producers;
- producers.insert(linalgOp);
- for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
- Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
- // When looking at dependences into, indexingOp is always OpOperand. We
- // could assert, but continue if this is not the case.
- if (!operandNumber)
- continue;
- if (!fusionOptions.indicesToFuse.count(*operandNumber))
- continue;
- if (isa<LinalgOp>(dependence.getDependentOp()))
- producers.insert(dependence.getDependentOp());
- }
-
- SmallVector<LinalgOp, 1> fusionOps;
- for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
- ++it) {
- auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
- if (producerLinalgOp && producers.count(producerLinalgOp))
- fusionOps.push_back(producerLinalgOp);
- }
- fusionOps.push_back(linalgOp);
-
- SmallVector<Value, 4> tileSizes =
- tilingOptions.tileSizeComputationFunction(rewriter, op);
- LinalgTilingOptions instanceTilingOptions = tilingOptions;
- instanceTilingOptions.setTileSizes(tileSizes);
- Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
- rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
- if (!tiledAndFusedOps)
- return failure();
-
- // Tile the unfused loops;
- SmallVector<Value, 4> unfusedLoopTileSizes;
- Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
- for (const auto &tileSize : enumerate(tileSizes)) {
- if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
- unfusedLoopTileSizes.push_back(zero);
- else
- unfusedLoopTileSizes.push_back(tileSize.value());
- }
- // Tile the loop only if there is a non-zero tile size.
- if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
- unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
- if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
- if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
- return cst.value() != 0;
- return true;
- })) {
- LinalgTilingOptions unfusedTilingOptions = tilingOptions;
- unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
- FailureOr<TiledLinalgOp> unfusedTiledOp =
- tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
- if (failed(unfusedTiledOp))
- return failure();
- rewriter.replaceOp(tiledAndFusedOps->op,
- getTiledOpResult(unfusedTiledOp.value()));
- tiledAndFusedOps->op = unfusedTiledOp->op;
- }
- op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.value()));
-
- filter.replaceLinalgTransformationFilter(rewriter,
- tiledAndFusedOps->op.getOperation());
- for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
- fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
- fusedOp.getOperation());
- }
- for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
- originalOpMarker.replaceLinalgTransformationFilter(
- rewriter, origProducerOp.getOperation());
- }
- rewriter.updateRootInPlace(op, [&]() {
- originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
- });
- return success();
-}
-
/// Linalg tiling pattern.
mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
MLIRContext *context, LinalgTilingOptions options,
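After `tileRootOperation` tiled only the fused loops, the removed `LinalgBaseTileAndFusePattern` tiled the remaining loops in a second step. Both steps rely on the Linalg convention that a tile size of 0 leaves a loop untiled. The hypothetical helpers `splitTileSizes` and `needsSecondTiling` below sketch how one tile-size vector is divided between the two steps. They use plain integers where the real code used `Value`s created with `arith::ConstantIndexOp`.

```cpp
#include <cassert>
#include <cstdint>
#include <set>
#include <utility>
#include <vector>

// Split `tileSizes` into (fused, unfused): the first tiling step keeps only
// the fused loops' sizes, the second keeps only the remaining loops' sizes.
// A size of 0 means "do not tile this loop".
std::pair<std::vector<int64_t>, std::vector<int64_t>>
splitTileSizes(const std::vector<int64_t> &tileSizes,
               const std::set<unsigned> &fusedLoops, unsigned numLoops) {
  for (unsigned loop : fusedLoops)
    assert(loop < tileSizes.size() && "fused loop without a tile size");
  std::vector<int64_t> fused(tileSizes), unfused(tileSizes);
  for (unsigned i = 0; i < tileSizes.size(); ++i) {
    if (fusedLoops.count(i))
      unfused[i] = 0; // already tiled in the first step
    else
      fused[i] = 0; // left for the second step
  }
  // Like the removed pattern, drop sizes beyond the op's loop count.
  if (unfused.size() > numLoops)
    unfused.resize(numLoops);
  return {fused, unfused};
}

// The second tiling step only runs if some remaining size is non-zero.
bool needsSecondTiling(const std::vector<int64_t> &unfused) {
  for (int64_t size : unfused)
    if (size != 0)
      return true;
  return false;
}
```

For a matmul tiled with sizes `{32, 64, 16}` where `i` and `j` are fused, the first step uses `{32, 64, 0}` and the second tiles only the reduction loop with `{0, 0, 16}`.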
diff --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir
deleted file mode 100644
index 787eff6f39ae..000000000000
--- a/mlir/test/Dialect/Linalg/fusion-pattern.mlir
+++ /dev/null
@@ -1,307 +0,0 @@
-// RUN: mlir-opt %s -test-linalg-fusion-transform-patterns -canonicalize -cse -split-input-file | FileCheck %s
-
-module {
- func.func @basic_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
- %arg2: memref<?x?xf32>) {
- %cst = arith.constant 0.000000e+00 : f32
- linalg.fill ins(%cst : f32) outs(%arg2 : memref<?x?xf32>)
- linalg.matmul {__internal_linalg_transform__ = "basic_fusion"}
- ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%arg2 : memref<?x?xf32>)
- return
- }
-}
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 32)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 64)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)>
-// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 32)>
-// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 64)>
-// CHECK: func @basic_fusion
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
-// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index
-// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
-// CHECK-DAG: %[[CST:.+]] = arith.constant 0.0{{.*}} : f32
-// CHECK-DAG: linalg.fill
-// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_original"
-// CHECK-SAME: ins(%[[CST]]{{.*}}outs(%[[ARG2]]
-// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
-// CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) =
-// CHECK-SAME: to (%[[M]], %[[N]])
-// CHECK-SAME: step (%[[C32]], %[[C64]]) {
-// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
-// CHECK: %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
-// CHECK: %[[SV1:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[K]]]
-// CHECK: %[[K_2:.+]] = memref.dim %[[ARG1]], %[[C0]]
-// CHECK: %[[TILE_N:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N]]]
-// CHECK: %[[SV2:.+]] = memref.subview %[[ARG1]][0, %[[IV1]]]
-// CHECK-SAME: %[[K_2]], %[[TILE_N]]
-// CHECK: %[[SV3:.+]] = memref.subview %[[ARG2]][%[[IV0]], %[[IV1]]]
-// CHECK-SAME: [%[[TILE_M]], %[[TILE_N]]]
-// CHECK: %[[M_2:.+]] = memref.dim %[[ARG2]], %[[C0]]
-// CHECK: %[[N_2:.+]] = memref.dim %[[ARG2]], %[[C1]]
-// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]])[%[[M_2]], %[[M]]]
-// CHECK: %[[TILE_N_3:.+]] = affine.min #[[MAP5]](%[[IV1]])[%[[N_2]], %[[N]]]
-// CHECK: %[[SV3_2:.+]] = memref.subview %[[ARG2]][%[[IV0]], %[[IV1]]]
-// CHECK-SAME: [%[[TILE_M_3]], %[[TILE_N_3]]]
-// CHECK: linalg.fill
-// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_producer"
-// CHECK-SAME: ins(%[[CST]]{{.*}}outs(%[[SV3_2]]
-// CHECK: scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] {
-// CHECK: %[[TILE_K:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K]]]
-// CHECK: %[[SV4:.+]] = memref.subview %[[SV1]][0, %[[IV2]]]
-// CHECK-SAME: [%[[TILE_M]], %[[TILE_K]]]
-// CHECK: %[[SV5:.+]] = memref.subview %[[SV2]][%[[IV2]], 0]
-// CHECK-SAME: [%[[TILE_K]], %[[TILE_N]]]
-// CHECK: linalg.matmul
-// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion"
-// CHECK-SAME: ins(%[[SV4]], %[[SV5]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
-// CHECK-SAME: outs(%[[SV3]] : memref<?x?xf32, #[[MAP1]]>)
-// CHECK: }
-// CHECK: }
-// CHECK: linalg.matmul
-// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_original"
-
-// -----
-
-module {
- func.func @matmul_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
- %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>,
- %arg4: memref<?x?xf32>) {
- linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%arg2 : memref<?x?xf32>)
- linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"}
- ins(%arg2, %arg3 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%arg4 : memref<?x?xf32>)
- return
- }
-}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 32)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 64)>
-// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 32)>
-// CHECK: func @matmul_fusion
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
-// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index
-// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
-// CHECK: linalg.matmul
-// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_original"
-// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG2]], %[[C0]]
-// CHECK: scf.parallel (%[[IV0:.+]]) =
-// CHECK-SAME: (%[[C0]]) to (%[[M]]) step (%[[C32]]) {
-// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
-// CHECK: %[[K2:.+]] = memref.dim %[[ARG2]], %[[C1]]
-// CHECK: %[[SV1:.+]] = memref.subview %[[ARG2]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[K2]]]
-// CHECK: %[[N:.+]] = memref.dim %[[ARG4]], %[[C1]]
-// CHECK: %[[SV2:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[N]]]
-// CHECK: %[[M_3:.+]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]])[%[[M_3]], %[[M]]]
-// CHECK: %[[K1:.+]] = memref.dim %[[ARG0]], %[[C1]]
-// CHECK: %[[SV3:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_3]], %[[K1]]]
-// CHECK: %[[SV1_2:.+]] = memref.subview %[[ARG2]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_3]], %[[K2]]]
-// CHECK: linalg.matmul
-// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_producer"
-// CHECK-SAME: ins(%[[SV3]], %[[ARG1]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
-// CHECK-SAME: outs(%[[SV1_2]] : memref<?x?xf32, #[[MAP1]]>)
-// CHECK: %[[N_2:.+]] = memref.dim %[[ARG3]], %[[C1]]
-// CHECK: scf.parallel (%[[IV1:.+]]) =
-// CHECK-SAME: (%[[C0]]) to (%[[N_2]]) step (%[[C64]]) {
-// CHECK-NEXT: scf.for %[[IV2:.+]] = %[[C0]] to %[[K2]] step %[[C16]] {
-// CHECK: %[[TILE_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K2]]]
-// CHECK: %[[SV6:.+]] = memref.subview %[[SV1]][0, %[[IV2]]]
-// CHECK-SAME: [%[[TILE_M]], %[[TILE_K]]]
-// CHECK: %[[TILE_N:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N_2]]]
-// CHECK: %[[SV7:.+]] = memref.subview %[[ARG3]][%[[IV2]], %[[IV1]]]
-// CHECK-SAME: [%[[TILE_K]], %[[TILE_N]]]
-// CHECK: %[[SV8:.+]] = memref.subview %[[SV2]][0, %[[IV1]]]
-// CHECK-SAME: [%[[TILE_M]], %[[TILE_N]]]
-// CHECK: linalg.matmul
-// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion"
-// CHECK-SAME: ins(%[[SV6]], %[[SV7]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
-// CHECK-SAME: outs(%[[SV8]] : memref<?x?xf32, #[[MAP1]]>)
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK: linalg.matmul
-// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_original"
-
-// -----
-
-module {
- func.func @matmul_plus_matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
- %arg2: memref<?x?xf32>) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %0 = memref.dim %arg2, %c0 : memref<?x?xf32>
- %1 = memref.dim %arg2, %c1 : memref<?x?xf32>
- %2 = memref.alloc(%0, %1) : memref<?x?xf32>
- linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%2 : memref<?x?xf32>)
- linalg.generic
- {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"],
- __internal_linalg_transform__ = "transpose_fusion"}
- ins(%2, %2 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%arg2 : memref<?x?xf32>) {
- ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
- %3 = arith.addf %arg3, %arg4 : f32
- linalg.yield %3 : f32
- }
- return
- }
-}
-// CHECK: func @matmul_plus_matmul
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK: %[[T2:.+]] = memref.alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32>
-// CHECK: linalg.matmul
-// CHECK-SAME: after_transpose_fusion_original
-// CHECK: scf.parallel (%[[ARG3:[a-zA-Z0-9_]+]], %[[ARG4:.[a-zA-Z0-9_]+]])
-// CHECK: %[[T5:.+]] = memref.subview %[[T2]][%[[ARG3]], %[[ARG4]]]
-// CHECK: %[[T6:.+]] = memref.subview %[[ARG2]][%[[ARG3]], %[[ARG4]]]
-// CHECK: %[[T8:.+]] = memref.subview %[[ARG0]][%[[ARG3]], 0]
-// CHECK: %[[T9:.+]] = memref.subview %[[ARG1]][0, %[[ARG4]]]
-// CHECK: %[[T10:.+]] = memref.subview %[[T2]][%[[ARG3]], %[[ARG4]]]
-// CHECK: linalg.matmul
-// CHECK-SAME: after_transpose_fusion_producer
-// CHECK-SAME: ins(%[[T8]], %[[T9]]
-// CHECK-SAME: outs(%[[T10]]
-// CHECK-NOT: linalg.matmul
-// CHECK: linalg.generic
-// CHECK-SAME: ins(%[[T5]], %[[T5]]
-// CHECK-SAME: outs(%[[T6]]
-// CHECK-SAME: after_transpose_fusion
-
-// -----
-
-module {
- func.func @matmul_plus_transpose_matmul(%arg0: memref<?x?xf32>,
- %arg1: memref<?x?xf32>,
- %arg2: memref<?x?xf32>) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %0 = memref.dim %arg2, %c0 : memref<?x?xf32>
- %1 = memref.dim %arg2, %c1 : memref<?x?xf32>
- %2 = memref.alloc(%0, %1) : memref<?x?xf32>
- linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%2 : memref<?x?xf32>)
- linalg.generic
- {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d1, d0)>,
- affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"],
- __internal_linalg_transform__ = "transpose_fusion"}
- ins(%2, %2 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%arg2 : memref<?x?xf32>) {
- ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
- %3 = arith.addf %arg3, %arg4 : f32
- linalg.yield %3 : f32
- }
- return
- }
-}
-// CHECK-LABEL: func @matmul_plus_transpose_matmul
-// CHECK-NOT: scf.parallel
-// CHECK-NOT: scf.for
-// CHECK: linalg.matmul
-// CHECK-NOT: scf.parallel
-// CHECK-NOT: scf.for
-// CHECK: linalg.generic
-// CHECK-NOT: scf.parallel
-// CHECK-NOT: scf.for
-
-// -----
-
-#map0 = affine_map<(d0)[s0] -> (32, -d0 + s0)>
-#map1 = affine_map<(d0)[s0] -> (64, -d0 + s0)>
-#map2 = affine_map<(d0)[s0] -> (16, -d0 + s0)>
-#map3 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-module {
- func.func @basic_no_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
- %arg2: memref<?x?xf32>) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c2 = arith.constant 2 : index
- %c32 = arith.constant 32 : index
- %c64 = arith.constant 64 : index
- %c16 = arith.constant 16 : index
- %cst = arith.constant 0.000000e+00 : f32
- linalg.fill ins(%cst : f32) outs(%arg2 : memref<?x?xf32>)
- %0 = memref.dim %arg0, %c0 : memref<?x?xf32>
- %1 = memref.dim %arg1, %c1 : memref<?x?xf32>
- %2 = memref.dim %arg0, %c1 : memref<?x?xf32>
- scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%0, %1) step (%c32, %c64) {
- scf.for %arg5 = %c0 to %2 step %c16 {
- %3 = affine.min #map0(%arg3)[%0]
- %4 = affine.min #map1(%arg4)[%1]
- %5 = affine.min #map2(%arg5)[%2]
- %6 = memref.subview %arg0[%arg3, %arg5] [%3, %5] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map3>
- %7 = memref.subview %arg1[%arg5, %arg4] [%5, %4] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map3>
- %8 = memref.subview %arg2[%arg3, %arg4] [%3, %4] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map3>
- linalg.matmul {__internal_linalg_transform__ = "basic_fusion"}
- ins(%6, %7 : memref<?x?xf32, #map3>, memref<?x?xf32, #map3>)
- outs(%8 : memref<?x?xf32, #map3>)
- }
- scf.yield
- }
- return
- }
-}
-// CHECK-LABEL: func @basic_no_fusion
-// CHECK-NOT: scf.parallel
-// CHECK: linalg.fill
-// CHECK: scf.parallel
-// CHECK: scf.for
-// CHECK-NOT: linalg.fill
-// CHECK: linalg.matmul
-
-// -----
-
-module {
- func.func @basic_conv_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
- %arg2: memref<?x?xf32>) {
- %cst = arith.constant 0.000000e+00 : f32
- linalg.fill ins(%cst : f32) outs(%arg2 : memref<?x?xf32>)
- linalg.conv_2d {__internal_linalg_transform__ = "basic_fusion"}
- ins(%arg1, %arg0 : memref<?x?xf32>, memref<?x?xf32>) outs(%arg2 : memref<?x?xf32>)
- return
- }
-}
-// CHECK: func @basic_conv_fusion
-// CHECK: linalg.fill
-// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_original"
-// CHECK: scf.parallel (%{{.+}}, %{{.+}}, %{{.+}})
-// CHECK-SAME: {
-// CHECK: linalg.fill
-// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_producer"
-// CHECK: linalg.conv_2d
-// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion"
-// CHECK: }
-// CHECK: linalg.conv_2d
-// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_original"
diff --git a/mlir/test/Dialect/Linalg/fusion-sequence.mlir b/mlir/test/Dialect/Linalg/fusion-sequence.mlir
deleted file mode 100644
index ffe85804a309..000000000000
--- a/mlir/test/Dialect/Linalg/fusion-sequence.mlir
+++ /dev/null
@@ -1,252 +0,0 @@
-// RUN: mlir-opt -pass-pipeline="func.func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),resolve-shaped-type-result-dims,canonicalize,cse" -split-input-file %s | FileCheck %s
-
-module {
- func.func @three_op_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
- %arg2: memref<?xf32>, %arg3 : memref<?x?xf32>) {
- %cst = arith.constant 0.000000e+00 : f32
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %d0 = memref.dim %arg0, %c0 : memref<?x?xf32>
- %d1 = memref.dim %arg1, %c1 : memref<?x?xf32>
- %0 = memref.alloc(%d0, %d1) : memref<?x?xf32>
- linalg.fill ins(%cst : f32) outs(%0 : memref<?x?xf32>)
- linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%0 : memref<?x?xf32>)
- linalg.generic
- {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d1)>,
- affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]}
- ins(%0, %arg2 : memref<?x?xf32>, memref<?xf32>)
- outs(%arg3 : memref<?x?xf32>) {
- ^bb0(%arg4 : f32, %arg5 : f32, %arg6 : f32) :
- %5 = arith.addf %arg4, %arg5 : f32
- linalg.yield %5 : f32
- }
- return
- }
-}
-
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
-// CHECK: func @three_op_fusion
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?xf32>
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK: %[[TEMP:.+]] = memref.alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32>
-// CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) = {{.*}} {
-// CHECK: %[[SV_TEMP_1:.+]] = memref.subview %[[TEMP]][%[[IV0]], %[[IV1]]]
-// CHECK-DAG: %[[SV_ARG2:.+]] = memref.subview %[[ARG2]][%[[IV1]]]
-// CHECK-DAG: %[[SV_ARG3:.+]] = memref.subview %[[ARG3]][%[[IV0]], %[[IV1]]]
-// CHECK-DAG: %[[SV_ARG0:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0]
-// CHECK-DAG: %[[SV_ARG1:.+]] = memref.subview %[[ARG1]][0, %[[IV1]]]
-// CHECK: %[[SV_TEMP_2:.+]] = memref.subview %[[TEMP]][%[[IV0]], %[[IV1]]]
-// CHECK: linalg.fill ins(%{{.+}}{{.*}}outs(%[[SV_TEMP_1]]
-// CHECK: linalg.matmul
-// CHECK-SAME: ins(%[[SV_ARG0]], %[[SV_ARG1]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP2]]>, memref<?x?xf32, #[[MAP2]]>)
-// CHECK-SAME: outs(%[[SV_TEMP_2]] : memref<?x?xf32, #[[MAP2]]>)
-// CHECK: linalg.generic
-// CHECK-SAME: ins(%[[SV_TEMP_1]], %[[SV_ARG2]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP2]]>, memref<?xf32, #[[MAP3]]>)
-// CHECK-SAME: outs(%[[SV_ARG3]] : memref<?x?xf32, #[[MAP2]]>)
-// CHECK: scf.yield
-// CHECK: }
-
-// -----
-
-module {
- func.func @sequence_of_matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
- %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>,
- %arg4: memref<?x?xf32>) {
- %cst = arith.constant 0.000000e+00 : f32
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %m = memref.dim %arg0, %c0 : memref<?x?xf32>
- %n1 = memref.dim %arg1, %c1 : memref<?x?xf32>
- %n2 = memref.dim %arg2, %c1 : memref<?x?xf32>
- %n3 = memref.dim %arg3, %c1 : memref<?x?xf32>
- %0 = memref.alloc(%m, %n1) : memref<?x?xf32>
- %1 = memref.alloc(%m, %n2) : memref<?x?xf32>
- linalg.fill ins(%cst : f32) outs(%0 : memref<?x?xf32>)
- linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%0 : memref<?x?xf32>)
- linalg.fill ins(%cst : f32) outs(%1 : memref<?x?xf32>)
- linalg.matmul ins(%0, %arg2 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%1 : memref<?x?xf32>)
- linalg.fill ins(%cst : f32) outs(%arg4 : memref<?x?xf32>)
- linalg.matmul ins(%1, %arg3 : memref<?x?xf32>, memref<?x?xf32>)
- outs(%arg4 : memref<?x?xf32>)
- return
- }
-}
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 16)>
-
-
-// CHECK: func @sequence_of_matmul
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
-// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[N1:.+]] = memref.dim %[[ARG1]], %[[C1]]
-// CHECK-DAG: %[[N2:.+]] = memref.dim %[[ARG2]], %[[C1]]
-// CHECK: %[[ALLOC1:.+]] = memref.alloc(%[[M]], %[[N1]])
-// CHECK: %[[ALLOC2:.+]] = memref.alloc(%[[M]], %[[N2]])
-// CHECK: scf.parallel (%[[IV0:.+]]) = (%[[C0]]) to (%[[M]])
-// CHECK-SAME: step (%[[C16]]) {
-// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
-// CHECK: %[[SV_ALLOC3:.+]] = memref.subview %[[ALLOC2]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[N2]]]
-// CHECK: %[[N3:.+]] = memref.dim %[[ARG4]], %[[C1]]
-// CHECK: %[[SV_ARG4:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[N3]]]
-// CHECK: %[[M_2:.+]] = memref.dim %[[ARG4]], %[[C0]]
-// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP2]](%[[IV0]])[%[[M_2]], %[[M]]]
-// CHECK: %[[SV_ARG4_2:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_3]], %[[N3]]]
-// CHECK: %[[SV_ALLOC1:.+]] = memref.subview %[[ALLOC1]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[N1]]]
-// CHECK: %[[TILE_M_5:.+]] = affine.min #[[MAP2]](%[[IV0]])[%[[M]], %[[M]]]
-// CHECK: %[[N0:.+]] = memref.dim %[[ARG0]], %[[C1]]
-// CHECK: %[[SV_ARG0:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_5]], %[[N0]]]
-// CHECK: %[[SV_ALLOC4:.+]] = memref.subview %[[ALLOC1]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_5]], %[[N1]]]
-// CHECK: linalg.fill ins(%{{.+}}{{.*}}outs(%[[SV_ALLOC1]]
-// CHECK: linalg.matmul ins(%[[SV_ARG0]], %[[ARG1]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
-// CHECK-SAME: outs(%[[SV_ALLOC4]] : memref<?x?xf32, #[[MAP1]]>)
-// CHECK: linalg.fill ins(%{{.+}}{{.*}}outs(%[[SV_ALLOC3]]
-// CHECK: linalg.matmul ins(%[[SV_ALLOC1]], %[[ARG2]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
-// CHECK-SAME: outs(%[[SV_ALLOC3]] : memref<?x?xf32, #[[MAP1]]>)
-// CHECK: linalg.fill ins(%{{.+}}{{.*}}outs(%[[SV_ARG4_2]]
-// CHECK: linalg.matmul ins(%[[SV_ALLOC3]], %[[ARG3]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
-// CHECK-SAME: outs(%[[SV_ARG4]] : memref<?x?xf32, #[[MAP1]]>)
-// CHECK: scf.yield
-// CHECK: }
-
-
-// -----
-
-module {
- func.func @tensor_op_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
- %arg2: tensor<?x?xf32>, %arg3: tensor<?xf32>)
- -> tensor<?x?xf32> {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
- %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
- %2 = tensor.dim %0, %c1 : tensor<?x?xf32>
- %3 = linalg.init_tensor [%1, %2] : tensor<?x?xf32>
- %4 = linalg.generic
- {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d0)>,
- affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]}
- ins(%0, %arg3 : tensor<?x?xf32>, tensor<?xf32>)
- outs(%3 : tensor<?x?xf32>) {
- ^bb0(%arg4: f32, %arg5: f32, %arg6: f32):
- %5 = arith.addf %arg4, %arg5 : f32
- linalg.yield %5 : f32
- } -> tensor<?x?xf32>
- return %4 : tensor<?x?xf32>
- }
-}
-// CHECK-LABEL: func @tensor_op_fusion
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?xf32>
-// CHECK: %[[INIT:.+]] = linalg.init_tensor
-// CHECK: %[[R0:.+]] = scf.for %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG5:.+]] = %[[INIT]]) -> (tensor<?x?xf32>) {
-// CHECK: %[[R1:.+]] = scf.for %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG7:.+]] = %[[ARG5]]) -> (tensor<?x?xf32>) {
-// CHECK-DAG: %[[STARG3:.+]] = tensor.extract_slice %[[ARG3]]
-// CHECK-DAG: %[[STARG7:.+]] = tensor.extract_slice %[[ARG7]]
-// CHECK-DAG: %[[STARG0:.+]] = tensor.extract_slice %[[ARG0]]
-// CHECK-DAG: %[[STARG1:.+]] = tensor.extract_slice %[[ARG1]]
-// CHECK-DAG: %[[STARG2:.+]] = tensor.extract_slice %[[ARG2]]
-// CHECK: %[[T0:.+]] = linalg.matmul
-// CHECK-SAME: ins(%[[STARG0]], %[[STARG1]] : tensor<?x?xf32>, tensor<?x?xf32>)
-// CHECK-SAME: outs(%[[STARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[T1:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[T0:.+]], %[[STARG3]] : tensor<?x?xf32>, tensor<?xf32>)
-// CHECK-SAME: outs(%[[STARG7]] : tensor<?x?xf32>)
-// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[T1]] into %[[ARG7]]
-// CHECK: scf.yield %[[RESULT]]
-// CHECK: }
-// CHECK: scf.yield %[[R1]]
-// CHECK: }
-// CHECK: return %[[R0]]
-
-// -----
-
-module {
- func.func @tensor_matmul_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
- %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>,
- %arg4: tensor<?x?xf32>, %arg5: tensor<?x?xf32>,
- %arg6: tensor<?x?xf32>) -> tensor<?x?xf32> {
- %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N0] * [N0, N1]
- %1 = linalg.matmul ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%arg4 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N1] * [N1, N2]
- %2 = linalg.matmul ins(%1, %arg5 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%arg6 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N2] * [N2, N3]
- return %2 : tensor<?x?xf32>
- }
-}
-
-// CHECK: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)>
-// CHECK: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 16)>
-
-// CHECK: func @tensor_matmul_fusion(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[M:.+]] = tensor.dim %[[ARG0]], %c0 : tensor<?x?xf32>
-// CHECK: %[[R0:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]] =
-// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ARG6]]) -> (tensor<?x?xf32>) {
-// CHECK: %[[TILE_M_1:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
-// CHECK: %[[N3:.+]] = tensor.dim %[[ARG8]], %[[C1]]
-// CHECK: %[[STARG6:.+]] = tensor.extract_slice %[[ARG8]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_1]], %[[N3]]]
-// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[M]], %[[M]]]
-// CHECK: %[[N2:.+]] = tensor.dim %[[ARG4]], %[[C1]]
-// CHECK: %[[STARG4:.+]] = tensor.extract_slice %[[ARG4]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_2]], %[[N2]]]
-// CHECK: %[[N0:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-// CHECK: %[[STARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_2]], %[[N0]]]
-// CHECK: %[[N1:.+]] = tensor.dim %[[ARG2]], %[[C1]]
-// CHECK: %[[STARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_2]], %[[N1]]]
-// CHECK: %[[T0:.+]] = linalg.matmul
-// CHECK-SAME: ins(%[[STARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>
-// CHECK-SAME: ) outs(%[[STARG2]] : tensor<?x?xf32>)
-// CHECK: %[[T1:.+]] = linalg.matmul
-// CHECK-SAME: ins(%[[T0]], %arg3 : tensor<?x?xf32>, tensor<?x?xf32>
-// CHECK-SAME: ) outs(%[[STARG4]] : tensor<?x?xf32>)
-// CHECK: %[[T2:.+]] = linalg.matmul
-// CHECK-SAME: ins(%[[T1]], %arg5 : tensor<?x?xf32>, tensor<?x?xf32>
-// CHECK-SAME: ) outs(%[[STARG6]] : tensor<?x?xf32>)
-// CHECK: %[[R1:.+]] = tensor.insert_slice %[[T2]]
-// CHECK-SAME: into %[[ARG8]][%[[IV0]], 0] [%[[TILE_M_1]], %[[N3]]]
-// CHECK: scf.yield %[[R1]] : tensor<?x?xf32>
-// CHECK: }
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
deleted file mode 100644
index 56f4c9d62817..000000000000
--- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
+++ /dev/null
@@ -1,193 +0,0 @@
-// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -resolve-shaped-type-result-dims -canonicalize -cse --split-input-file | FileCheck %s
-
-module {
- func.func @matmul_fusion(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
- %AB_init: tensor<?x?xf32>, %C: tensor<?x?xf32>,
- %ABC_init: tensor<?x?xf32>) -> tensor<?x?xf32> {
- %AB = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%AB_init : tensor<?x?xf32>) -> tensor<?x?xf32> // <MxN1> <N1xN2>
- %ABC = linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"}
- ins(%AB, %C : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%ABC_init : tensor<?x?xf32>) -> tensor<?x?xf32> // <MxN2> <N2xN3>
- return %ABC : tensor<?x?xf32>
- }
-}
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 32)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 64)>
-// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 32)>
-
-// CHECK: func @matmul_fusion
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
-// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index
-// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
-// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] =
-// CHECK-SAME: %[[C0]] to %[[M]] step %[[C32]]
-// CHECK-SAME: iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor<?x?xf32>) {
-// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[M]]]
-// CHECK: %[[N3:.+]] = tensor.dim %[[ARG6]], %[[C1]]
-// CHECK: %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_2]], %[[N3]]]
-// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP5]](%[[IV0]])[%[[M]], %[[M]]]
-// CHECK: %[[N1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-// CHECK: %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_3]], %[[N1]]]
-// CHECK: %[[N2_2:.+]] = tensor.dim %[[ARG2]], %[[C1]]
-// CHECK: %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_3]], %[[N2_2]]]
-// CHECK: %[[LHS:.+]] = linalg.matmul
-// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_producer"
-// CHECK-SAME: ins(%[[ST_ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>)
-// CHECK-SAME: outs(%[[ST_ARG2]] : tensor<?x?xf32>)
-// CHECK: %[[N2:.+]] = tensor.dim %[[ARG1]], %[[C1]]
-// CHECK: %[[N3_2:.+]] = tensor.dim %[[ARG3]], %[[C1]]
-// CHECK: %[[YIELD0:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] =
-// CHECK-SAME: %[[C0]] to %[[N3_2]] step %[[C64]]
-// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ST_ARG6]]) -> (tensor<?x?xf32>) {
-// CHECK: %[[YIELD1:.+]] = scf.for %[[IV2:[a-zA-Z0-9]+]] =
-// CHECK-SAME: %[[C0]] to %[[N2]] step %[[C16]]
-// CHECK-SAME: iter_args(%[[ARG10:.+]] = %[[ARG8]]) -> (tensor<?x?xf32>) {
-// CHECK: %[[TILE_N2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[N2]]]
-// CHECK: %[[ST_LHS:.+]] = tensor.extract_slice %[[LHS]][0, %[[IV2]]]
-// CHECK-SAME: [%[[TILE_M_3]], %[[TILE_N2]]]
-// CHECK: %[[TILE_N3:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N3_2]]]
-// CHECK: %[[ST_ARG3:.+]] = tensor.extract_slice %[[ARG3]][%[[IV2]], %[[IV1]]]
-// CHECK-SAME: [%[[TILE_N2]], %[[TILE_N3]]]
-// CHECK: %[[M_4:.+]] = tensor.dim %[[ARG10]], %[[C0]]
-// CHECK: %[[ST_ARG4:.+]] = tensor.extract_slice %[[ARG10]][0, %[[IV1]]]
-// CHECK-SAME: [%[[M_4]], %[[TILE_N3]]]
-// CHECK: %[[ST_RESULT:.+]] = linalg.matmul
-// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion"
-// CHECK-SAME: ins(%[[ST_LHS]], %[[ST_ARG3]]
-// CHECK-SAME: : tensor<?x?xf32>, tensor<?x?xf32>)
-// CHECK-SAME: outs(%[[ST_ARG4]] : tensor<?x?xf32>)
-// CHECK: %[[UPDATE1:.+]] = tensor.insert_slice %[[ST_RESULT]]
-// CHECK-SAME: into %[[ARG10]][0, %[[IV1]]] [%[[M_4]], %[[TILE_N3]]]
-// CHECK: scf.yield %[[UPDATE1]]
-// CHECK: }
-// CHECK: scf.yield %[[YIELD1]]
-// CHECK: }
-// CHECK: %[[UPDATE0:.+]] = tensor.insert_slice %[[YIELD0]] into
-// CHECK-SAME: %[[ARG6]][%[[IV0]], 0] [%[[TILE_M_2]], %[[N3]]]
-// CHECK: scf.yield %[[UPDATE0]]
-// CHECK: }
-// CHECK: return %[[RESULT]]
-
-// -----
-
-module {
- func.func @matmul_plus_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
- %arg2: tensor<?x?xf32>) -> tensor<?x?xf32>{
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %0 = tensor.dim %arg2, %c0 : tensor<?x?xf32>
- %1 = tensor.dim %arg2, %c1 : tensor<?x?xf32>
- %2 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
- %3 = tensor.dim %2, %c0 : tensor<?x?xf32>
- %4 = tensor.dim %2, %c1 : tensor<?x?xf32>
- %5 = linalg.init_tensor [%3, %4] : tensor<?x?xf32>
- %6 = linalg.generic
- {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"],
- __internal_linalg_transform__ = "transpose_fusion"}
- ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%5 : tensor<?x?xf32>) {
- ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
- %7 = arith.addf %arg3, %arg4 : f32
- linalg.yield %7 : f32
- } -> tensor<?x?xf32>
- return %6 : tensor<?x?xf32>
- }
-}
-// CHECK: func @matmul_plus_matmul
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]]
-// CHECK-SAME: iter_args(%[[ARG4:.+]] = %{{[a-zA-Z0-9_]+}})
-// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:[a-zA-Z0-9_]+]]
-// CHECK-SAME: iter_args(%[[ARG6:.+]] = %[[ARG4]])
-// CHECK: %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], %[[IV1]]]
-// CHECK: %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
-// CHECK: %[[ST_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
-// CHECK: %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]]
-// CHECK: %[[LHS:.+]] = linalg.matmul
-// CHECK-SAME: ins(%[[ST_ARG0]], %[[ST_ARG1]]
-// CHECK-SAME: : tensor<?x?xf32>, tensor<?x?xf32>)
-// CHECK-SAME: outs(%[[ST_ARG2]] : tensor<?x?xf32>)
-// CHECK: %[[ST_RESULT:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[LHS]] : tensor<?x?xf32>)
-// CHECK-SAME: outs(%[[ST_ARG6]] : tensor<?x?xf32>)
-// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[ST_RESULT]]
-// CHECK-SAME: into %[[ARG6]][%[[IV0]], %[[IV1]]]
-// CHECK: scf.yield %[[UPDATE]]
-// CHECK: scf.yield %[[YIELD]]
-// CHECK: return %[[RESULT]]
-
-// -----
-
-module {
- func.func @matmul_out_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
- %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
- %c0 = arith.constant 0.0 : f32
- %0 = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
- %1 = linalg.matmul {__internal_linalg_transform__ = "out_fusion"}
- ins(%arg1, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
- return %1 : tensor<?x?xf32>
- }
-}
-
-// CHECK-LABEL: func @matmul_out_fusion(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK: %[[C0:.*]] = arith.constant 0.0{{.*}} : f32
-// CHECK-NOT: fill
-// CHECK: scf.for %[[I:.*]]{{.*}}iter_args(%{{.*}} = %[[ARG0]]) -> (tensor<?x?xf32>) {
-// CHECK: scf.for %[[J:.*]]
-// CHECK: %[[ST:.*]] = tensor.extract_slice %[[ARG0]]
-// CHECK: %[[ST_FILL:.*]] = linalg.fill
-// CHECK-SAME: {__internal_linalg_transform__ = "after_out_fusion_producer"}
-// CHECK-SAME: ins(%[[C0]] : f32) outs(%[[ST]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[ST_MM_RES:.*]] = scf.for %[[K:.*]]{{.*}}iter_args(%[[BB:.*]] = %[[ST_FILL]]) -> (tensor<?x?xf32>) {
-// CHECK-NOT: fill
-// CHECK: %[[ST_FILL_SUB:.*]] = tensor.extract_slice %[[BB]][0, 0]
-// CHECK: %[[ST_MM_SUB:.*]] = linalg.matmul {__internal_linalg_transform__ = "after_out_fusion"} ins(%{{.*}}, %{{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ST_FILL_SUB]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[ST_MM:.*]] = tensor.insert_slice %[[ST_MM_SUB]] into %[[BB]]
-// CHECK: scf.yield %[[ST_MM]] : tensor<?x?xf32>
-// CHECK: %[[MM:.*]] = tensor.insert_slice %[[ST_MM_RES]] into {{.*}}
-// CHECK: scf.yield %[[MM]] : tensor<?x?xf32>
-
-// -----
-
-module {
- func.func @generic_plus_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
- %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
- %c0 = arith.constant 0.0 : f32
- %0 = linalg.generic {
- indexing_maps = [affine_map<(m, n) -> ()>, affine_map<(m, n) -> (m, n)>],
- iterator_types = ["parallel", "parallel"]}
- ins(%c0 : f32)
- outs(%arg0: tensor<?x?xf32>) {
- ^bb(%0: f32, %1: f32) :
- linalg.yield %0 : f32
- } -> tensor<?x?xf32>
- %1 = linalg.matmul {__internal_linalg_transform__ = "out_fusion"}
- ins(%arg1, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
- return %1 : tensor<?x?xf32>
- }
-}
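The removed `@matmul_out_fusion` test checked that the `linalg.fill` initializing the matmul's `outs` operand gets fused into the tiled loop nest: each `(I, J)` output tile is filled just before the reduction loop over `K` accumulates into it, and the `CHECK-NOT: fill` lines assert that no whole-tensor fill remains. The C++ below is only a rough sketch of that loop structure on plain row-major arrays, not MLIR code; `Matrix`, `fillThenMatmul` and `tiledFillMatmul` are hypothetical names. The tile sizes 32/64/16 used in the check mirror the ones the deleted `fillMatmulPattern` helper passed.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical minimal row-major matrix, for illustration only.
struct Matrix {
  std::size_t rows, cols;
  std::vector<float> data;
  Matrix(std::size_t r, std::size_t c, float init = 0.0f)
      : rows(r), cols(c), data(r * c, init) {}
  float &at(std::size_t i, std::size_t j) { return data[i * cols + j]; }
  float at(std::size_t i, std::size_t j) const { return data[i * cols + j]; }
};

// Unfused reference: fill the whole output with zero, then run the matmul.
Matrix fillThenMatmul(const Matrix &a, const Matrix &b) {
  assert(a.cols == b.rows);
  Matrix c(a.rows, b.cols, /*init=*/42.0f);  // stale contents before the fill
  std::fill(c.data.begin(), c.data.end(), 0.0f);
  for (std::size_t i = 0; i < a.rows; ++i)
    for (std::size_t j = 0; j < b.cols; ++j)
      for (std::size_t k = 0; k < a.cols; ++k)
        c.at(i, j) += a.at(i, k) * b.at(k, j);
  return c;
}

// Tiled and fused (hypothetical helper): each (i, j) output tile is
// zero-filled right before the tiled k loop accumulates into it, so the
// fill never runs on the whole tensor up front.
Matrix tiledFillMatmul(const Matrix &a, const Matrix &b, std::size_t tm,
                       std::size_t tn, std::size_t tk) {
  assert(a.cols == b.rows && tm > 0 && tn > 0 && tk > 0);
  Matrix c(a.rows, b.cols, /*init=*/42.0f);
  for (std::size_t i0 = 0; i0 < a.rows; i0 += tm) {
    const std::size_t iEnd = std::min(i0 + tm, a.rows);
    for (std::size_t j0 = 0; j0 < b.cols; j0 += tn) {
      const std::size_t jEnd = std::min(j0 + tn, b.cols);
      // Fused producer: fill only this output tile.
      for (std::size_t i = i0; i < iEnd; ++i)
        for (std::size_t j = j0; j < jEnd; ++j) c.at(i, j) = 0.0f;
      // Consumer: the reduction loop is tiled too and accumulates in place.
      for (std::size_t k0 = 0; k0 < a.cols; k0 += tk) {
        const std::size_t kEnd = std::min(k0 + tk, a.cols);
        for (std::size_t i = i0; i < iEnd; ++i)
          for (std::size_t j = j0; j < jEnd; ++j)
            for (std::size_t k = k0; k < kEnd; ++k)
              c.at(i, j) += a.at(i, k) * b.at(k, j);
      }
    }
  }
  return c;
}
```

The fill has to sit outside the `k` loop, as in the CHECK lines (the filled tile seeds the `iter_args` of the `K` loop); filling inside it would zero the partial sums on every reduction step.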
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
index 81e2bfbe2d9f..d1ca2d2c4625 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -230,3 +230,174 @@ func.func @interchange_matmul_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?
// CHECK-SAME: outs(%[[INIT_TILE_2]] :
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]]
// CHECK scf.yield %[[INSERT]]
+
+// -----
+
+func.func @matmul_plus_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+ %arg2: tensor<?x?xf32>) -> tensor<?x?xf32>{
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = tensor.dim %arg2, %c0 : tensor<?x?xf32>
+ %1 = tensor.dim %arg2, %c1 : tensor<?x?xf32>
+ %2 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %3 = tensor.dim %2, %c0 : tensor<?x?xf32>
+ %4 = tensor.dim %2, %c1 : tensor<?x?xf32>
+ %5 = linalg.init_tensor [%3, %4] : tensor<?x?xf32>
+ %6 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"],
+ __internal_linalg_transform__ = "gemm_plus_gemm_fusion"}
+ ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%5 : tensor<?x?xf32>) {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
+ %7 = arith.addf %arg3, %arg4 : f32
+ linalg.yield %7 : f32
+ } -> tensor<?x?xf32>
+ return %6 : tensor<?x?xf32>
+}
+// This fuses as expected, but the gemm operation is inlined twice. It should be CSE-d, but isn't today.
+
+// CHECK: func @matmul_plus_matmul
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: iter_args(%[[ARG4:.+]] = %{{[a-zA-Z0-9_]+}})
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: iter_args(%[[ARG6:.+]] = %[[ARG4]])
+// CHECK-DAG: %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
+// CHECK-DAG: %[[ST_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
+// CHECK-DAG: %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]]
+// CHECK: %[[LHS:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[ST_ARG0]], %[[ST_ARG1]] :
+// CHECK-SAME: outs(%[[ST_ARG2]] :
+// CHECK-DAG: %[[ST_ARG0_1:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
+// CHECK-DAG: %[[ST_ARG1_1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
+// CHECK-DAG: %[[ST_ARG2_1:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]]
+// CHECK: %[[RHS:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[ST_ARG0_1]], %[[ST_ARG1_1]] :
+// CHECK-SAME: outs(%[[ST_ARG2_1]] :
+// CHECK: %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], %[[IV1]]]
+// CHECK: %[[ST_RESULT:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME: outs(%[[ST_ARG6]] :
+// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[ST_RESULT]]
+// CHECK-SAME: into %[[ARG6]][%[[IV0]], %[[IV1]]]
+// CHECK: scf.yield %[[UPDATE]]
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @matmul_plus_transpose_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+ %arg2: tensor<?x?xf32>) -> tensor<?x?xf32>{
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = tensor.dim %arg2, %c0 : tensor<?x?xf32>
+ %1 = tensor.dim %arg2, %c1 : tensor<?x?xf32>
+ %2 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %3 = tensor.dim %2, %c0 : tensor<?x?xf32>
+ %4 = tensor.dim %2, %c1 : tensor<?x?xf32>
+ %5 = linalg.init_tensor [%3, %4] : tensor<?x?xf32>
+ %6 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d1, d0)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"],
+ __internal_linalg_transform__ = "gemm_plus_gemm_fusion"}
+ ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%5 : tensor<?x?xf32>) {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
+ %7 = arith.addf %arg3, %arg4 : f32
+ linalg.yield %7 : f32
+ } -> tensor<?x?xf32>
+ return %6 : tensor<?x?xf32>
+}
+// CHECK: func @matmul_plus_transpose_matmul
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: iter_args(%[[ARG4:.+]] = %{{[a-zA-Z0-9_]+}})
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: iter_args(%[[ARG6:.+]] = %[[ARG4]])
+// CHECK-DAG: %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
+// CHECK-DAG: %[[ST_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
+// CHECK-DAG: %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]]
+// CHECK: %[[LHS:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[ST_ARG0]], %[[ST_ARG1]]
+// CHECK-SAME: : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[ST_ARG2]] : tensor<?x?xf32>)
+// CHECK-DAG: %[[STR_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0]
+// CHECK-DAG: %[[STR_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV0]]]
+// CHECK-DAG: %[[STR_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV1]], %[[IV0]]]
+// CHECK: %[[RHS:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[STR_ARG0]], %[[STR_ARG1]] :
+// CHECK-SAME: outs(%[[STR_ARG2]] :
+// CHECK: %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], %[[IV1]]]
+// CHECK: %[[ST_RESULT:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME: outs(%[[ST_ARG6]] :
+// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[ST_RESULT]]
+// CHECK-SAME: into %[[ARG6]][%[[IV0]], %[[IV1]]]
+// CHECK: scf.yield %[[UPDATE]]
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @matmul_sequence_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+ %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>, %arg4: tensor<?x?xf32>,
+ %arg5: tensor<?x?xf32>, %arg6: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N0] * [N0, N1]
+ %1 = linalg.matmul ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg4 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N1] * [N1, N2]
+ %2 = linalg.matmul
+ {__internal_linalg_transform__ = "gemm_sequence_fusion"}
+ ins(%1, %arg5 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg6 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N2] * [N2, N3]
+ return %2 : tensor<?x?xf32>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK: func @matmul_sequence_fusion(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[N0:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[ORIG_GEMM1:.+]] = linalg.matmul ins(%[[ARG0]], %[[ARG1]] :
+// CHECK-DAG: %[[N1:.+]] = tensor.dim %[[ORIG_GEMM1]], %[[C1]]
+// CHECK-DAG: %[[ORIG_GEMM2:.+]] = linalg.matmul ins(%[[ORIG_GEMM1]], %[[ARG3]] :
+// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C0]]
+// CHECK-DAG: %[[N2:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C1]]
+// CHECK-DAG: %[[N3:.+]] = tensor.dim %[[ARG5]], %[[C1]]
+// CHECK: %[[R0:.+]] = scf.for %[[IV:[a-zA-Z0-9_]+]] =
+// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ARG6]]) -> (tensor<?x?xf32>) {
+// CHECK-DAG: %[[TILE_M:.+]] = affine.min #[[MAP]](%[[IV]])[%{{.+}}, %[[M]]]
+// CHECK-DAG: %[[SLICE_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[TILE_M]], %[[N0]]]
+// CHECK-DAG: %[[SLICE_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, 0] [%[[N0]], %[[N1]]]
+// CHECK-DAG: %[[SLICE_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV]], 0] [%[[TILE_M]], %[[N1]]]
+// CHECK-DAG: %[[TILE_GEMM1:.+]] = linalg.matmul ins(%[[SLICE_ARG0]], %[[SLICE_ARG1]] :
+// CHECK-SAME: outs(%[[SLICE_ARG2]] :
+// CHECK-DAG: %[[SLICE_ARG3:.+]] = tensor.extract_slice %[[ARG3]][0, 0] [%[[N1]], %[[N2]]]
+// CHECK-DAG: %[[SLICE_ARG4:.+]] = tensor.extract_slice %[[ARG4]][%[[IV]], 0] [%[[TILE_M]], %[[N2]]]
+// CHECK-DAG: %[[TILE_GEMM2:.+]] = linalg.matmul ins(%[[TILE_GEMM1]], %[[SLICE_ARG3]] :
+// CHECK-SAME: outs(%[[SLICE_ARG4]] :
+// CHECK-DAG: %[[SLICE_ARG5:.+]] = tensor.extract_slice %[[ARG5]][0, 0] [%[[N2]], %[[N3]]]
+// CHECK-DAG: %[[SLICE_ARG6:.+]] = tensor.extract_slice %[[ARG8]][%[[IV]], 0] [%[[TILE_M]], %[[N3]]]
+// CHECK-DAG: %[[TILE_GEMM3:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[TILE_GEMM2]], %[[SLICE_ARG5]] :
+// CHECK-SAME: outs(%[[SLICE_ARG6]] :
+// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[TILE_GEMM3]] into %[[ARG8]][%[[IV]], 0] [%[[TILE_M]], %[[N3]]]
+// CHECK: scf.yield %[[UPDATE]]
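The `gemm_sequence_fusion` pattern is registered with a single tile size, `{10}`. Judging from the CHECK lines of `@matmul_sequence_fusion`, only the `M` loop is tiled: there is a single `scf.for`, the slices of `%arg1`, `%arg3` and `%arg5` start at `[0, 0]` and cover the full extents, and the tile height is `min(10, M - iv)` via `#MAP`. The C++ below is a hedged sketch of what that fused loop computes on plain arrays, not the MLIR API; `Matrix`, `matmul`, `rowSlice` and `fusedSequence` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical row-major stand-in for tensor<?x?xf32>.
struct Matrix {
  std::size_t rows, cols;
  std::vector<float> data;
  Matrix(std::size_t r, std::size_t c, float init = 0.0f)
      : rows(r), cols(c), data(r * c, init) {}
  float &at(std::size_t i, std::size_t j) { return data[i * cols + j]; }
  float at(std::size_t i, std::size_t j) const { return data[i * cols + j]; }
};

// out = init + lhs * rhs, i.e. linalg.matmul accumulating into its outs.
Matrix matmul(const Matrix &lhs, const Matrix &rhs, const Matrix &init) {
  assert(lhs.cols == rhs.rows);
  assert(init.rows == lhs.rows && init.cols == rhs.cols);
  Matrix out = init;
  for (std::size_t i = 0; i < lhs.rows; ++i)
    for (std::size_t j = 0; j < rhs.cols; ++j)
      for (std::size_t k = 0; k < lhs.cols; ++k)
        out.at(i, j) += lhs.at(i, k) * rhs.at(k, j);
  return out;
}

// Rows [r0, r0 + n) over all columns, like tensor.extract_slice at [r0, 0].
Matrix rowSlice(const Matrix &m, std::size_t r0, std::size_t n) {
  Matrix s(n, m.cols);
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = 0; j < m.cols; ++j) s.at(i, j) = m.at(r0 + i, j);
  return s;
}

// Three chained matmuls fused under one loop over M (hypothetical helper).
// Only row slices of arg0, arg2, arg4 and the carried result depend on iv;
// arg1, arg3 and arg5 are always used whole.
Matrix fusedSequence(const Matrix &arg0, const Matrix &arg1, const Matrix &arg2,
                     const Matrix &arg3, const Matrix &arg4, const Matrix &arg5,
                     const Matrix &arg6, std::size_t tileM) {
  assert(tileM > 0);
  const std::size_t m = arg0.rows;
  Matrix result = arg6;  // plays the role of the scf.for iter_arg
  for (std::size_t iv = 0; iv < m; iv += tileM) {
    // Same bound as #MAP: min(tileM, M - iv); the last tile may be partial.
    const std::size_t tile = std::min(tileM, m - iv);
    Matrix gemm1 =
        matmul(rowSlice(arg0, iv, tile), arg1, rowSlice(arg2, iv, tile));
    Matrix gemm2 = matmul(gemm1, arg3, rowSlice(arg4, iv, tile));
    Matrix gemm3 = matmul(gemm2, arg5, rowSlice(result, iv, tile));
    // tensor.insert_slice of the tile back into the carried result.
    for (std::size_t i = 0; i < tile; ++i)
      for (std::size_t j = 0; j < result.cols; ++j)
        result.at(iv + i, j) = gemm3.at(i, j);
  }
  return result;
}
```

Each row tile of the final result depends only on the matching rows of `%arg0`, `%arg2` and `%arg4`, which is why tiling `M` alone lets all three matmuls be fused without recomputing any producer.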
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
index 67651a98f794..c5b27c53e8cc 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -23,130 +23,6 @@
using namespace mlir;
using namespace mlir::linalg;
-/// Use this to safely fill patterns for this test, since RewritePatternSet::add
-/// forwards Rvalues only to the first pattern.
-template <typename OpTy, LinalgTilingLoopType LoopType>
-static void fillFusionPattern(MLIRContext *context,
- const LinalgDependenceGraph &dependenceGraph,
- RewritePatternSet &patterns,
- const Twine &testCase,
- ArrayRef<int64_t> tileSizes,
- ArrayRef<int64_t> indicesToFuse) {
- patterns.add<LinalgTileAndFusePattern<OpTy>>(
- context, dependenceGraph,
- LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(LoopType),
- LinalgFusionOptions().setIndicesToFuse(indicesToFuse),
- LinalgTransformationFilter(
- StringAttr::get(context, testCase + "_fusion"),
- StringAttr::get(context, "after_" + testCase + "_fusion")),
- LinalgTransformationFilter(
- ArrayRef<StringAttr>(),
- StringAttr::get(context, "after_" + testCase + "_fusion_producer")),
- LinalgTransformationFilter(
- ArrayRef<StringAttr>(),
- StringAttr::get(context, "after_" + testCase + "_fusion_original")));
-}
-
-template <LinalgTilingLoopType LoopType>
-static void fillFusionPatterns(MLIRContext *context,
- const LinalgDependenceGraph &dependenceGraph,
- RewritePatternSet &patterns) {
- fillFusionPattern<Conv2DOp, LoopType>(context, dependenceGraph, patterns,
- /*testCase=*/"basic",
- /*tileSizes=*/{32, 64, 16},
- /*indicesToFuse=*/{2});
-
- auto fillMatmulPattern = [&](const Twine &testCase,
- ArrayRef<int64_t> indicesToFuse) {
- fillFusionPattern<MatmulOp, LoopType>(context, dependenceGraph, patterns,
- testCase, /*tileSizes=*/{32, 64, 16},
- indicesToFuse);
- };
- fillMatmulPattern(/*testCase=*/"basic",
- /*indicesToFuse=*/{2});
- fillMatmulPattern(/*testCase=*/"lhs",
- /*indicesToFuse=*/{0});
- fillMatmulPattern(/*testCase=*/"out",
- /*indicesToFuse=*/{2});
- fillMatmulPattern(/*testCase=*/"rhs",
- /*indicesToFuse=*/{1});
- fillMatmulPattern(/*testCase=*/"two_operand",
- /*indicesToFuse=*/{0, 2});
-
- fillFusionPattern<GenericOp, LoopType>(context, dependenceGraph, patterns,
- /*testCase=*/"transpose",
- /*tileSizes=*/{32, 64},
- /*indicesToFuse=*/{0, 1});
-}
-
-namespace {
-template <LinalgTilingLoopType LoopType>
-struct TestLinalgFusionTransforms
- : public PassWrapper<TestLinalgFusionTransforms<LoopType>,
- OperationPass<func::FuncOp>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFusionTransforms)
-
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
- scf::SCFDialect>();
- }
- TestLinalgFusionTransforms() = default;
- TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {}
-
- void runOnOperation() override {
- MLIRContext *context = &this->getContext();
- func::FuncOp funcOp = this->getOperation();
- RewritePatternSet fusionPatterns(context);
- Aliases alias;
- LinalgDependenceGraph dependenceGraph =
- LinalgDependenceGraph::buildDependenceGraph(alias, funcOp);
- fillFusionPatterns<LoopType>(context, dependenceGraph, fusionPatterns);
- (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns));
- }
-};
-
-struct TestLinalgFusionTransformsParallelLoops
- : public TestLinalgFusionTransforms<LinalgTilingLoopType::ParallelLoops> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
- TestLinalgFusionTransformsParallelLoops)
-
- StringRef getArgument() const final {
- return "test-linalg-fusion-transform-patterns";
- }
- StringRef getDescription() const final {
- return "Test Linalg fusion transformation patterns by applying them "
- "greedily.";
- }
-};
-
-struct TestLinalgFusionTransformsLoops
- : public TestLinalgFusionTransforms<LinalgTilingLoopType::Loops> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFusionTransformsLoops)
-
- StringRef getArgument() const final {
- return "test-linalg-tensor-fusion-transform-patterns";
- }
- StringRef getDescription() const final {
- return "Test Linalg on tensor fusion transformation "
- "patterns by applying them greedily.";
- }
-};
-
-struct TestLinalgFusionTransformsTiledLoops
- : public TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
- TestLinalgFusionTransformsTiledLoops)
-
- StringRef getArgument() const final {
- return "test-linalg-tiled-loop-fusion-transform-patterns";
- }
- StringRef getDescription() const final {
- return "Test Linalg on tensor fusion transformation "
- "patterns by applying them greedily.";
- }
-};
-} // namespace
-
static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) {
OpBuilder b(f);
DenseSet<Operation *> eraseSet;
@@ -236,82 +112,13 @@ struct TestLinalgGreedyFusion
} while (succeeded(fuseLinalgOpsGreedily(getOperation())));
}
};
-
-/// Pass to test tile and fuse of sequence of operations. Intended only for
-/// testing.
-struct TestLinalgTileAndFuseSequencePass
- : public PassWrapper<TestLinalgTileAndFuseSequencePass,
- OperationPass<func::FuncOp>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
- TestLinalgTileAndFuseSequencePass)
-
- StringRef getArgument() const final { return "test-linalg-tile-and-fuse"; }
- StringRef getDescription() const final {
- return "Test Linalg tiling and fusion of a sequence of Linalg operations.";
- }
- TestLinalgTileAndFuseSequencePass() = default;
- TestLinalgTileAndFuseSequencePass(
- const TestLinalgTileAndFuseSequencePass &pass)
- : PassWrapper(pass){};
-
- ListOption<int64_t> tileSizes{*this, "tile-sizes",
- llvm::cl::desc("Tile sizes to use for ops")};
-
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
- scf::SCFDialect>();
- }
-
- void runOnOperation() override {
- func::FuncOp funcOp = getOperation();
- auto &blocks = funcOp.getBody().getBlocks();
- if (!llvm::hasSingleElement(blocks)) {
- return;
- }
- SmallVector<LinalgOp, 2> linalgOps =
- llvm::to_vector<2>(blocks.front().getOps<LinalgOp>());
- Aliases aliases;
- LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
- OpBuilder builder(funcOp.getContext());
- linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops;
- if (llvm::any_of(linalgOps, [](LinalgOp linalgOp) {
- return linalgOp.hasTensorSemantics();
- }))
- loopType = LinalgTilingLoopType::Loops;
- Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps(
- builder, linalgOps, dependenceGraph,
- LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType));
- if (!tileAndFuseOps)
- return signalPassFailure();
- if (linalgOps.back().hasTensorSemantics()) {
- linalgOps.back().getOperation()->replaceAllUsesWith(
- tileAndFuseOps->fusedLoops.front());
- }
- for (auto op : linalgOps)
- if (op.hasBufferSemantics())
- op.erase();
- }
-};
-
} // namespace
namespace mlir {
namespace test {
-void registerTestLinalgFusionTransforms() {
- PassRegistration<TestLinalgFusionTransformsParallelLoops>();
-}
-void registerTestLinalgTensorFusionTransforms() {
- PassRegistration<TestLinalgFusionTransformsLoops>();
-}
-void registerTestLinalgTiledLoopFusionTransforms() {
- PassRegistration<TestLinalgFusionTransformsTiledLoops>();
-}
void registerTestLinalgGreedyFusion() {
PassRegistration<TestLinalgGreedyFusion>();
}
-void registerTestLinalgTileAndFuseSequencePass() {
- PassRegistration<TestLinalgTileAndFuseSequencePass>();
-}
} // namespace test
} // namespace mlir
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index 214a4053e232..5c603a55d741 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -191,6 +191,12 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
addPatternForTiling<
TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
context, patterns, "gemm_interchange_fusion", {10, 20}, {1, 0});
+ addPatternForTiling<
+ TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
+ context, patterns, "gemm_plus_gemm_fusion", {10, 20});
+ addPatternForTiling<
+ TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
+ context, patterns, "gemm_sequence_fusion", {10});
return;
}
}
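The `gemm_plus_gemm_fusion` pattern above tiles the two parallel loops of the consumer `linalg.generic` with sizes `{10, 20}`. In `@matmul_plus_transpose_matmul` the second operand is read through `(d0, d1) -> (d1, d0)`, so for the output tile at `[%IV0, %IV1]` the fused producer matmul is sliced at the transposed offsets `[%IV1, %IV0]` (the `STR_ARG*` CHECK lines). The C++ below is a rough analogue on plain arrays, not the TilingInterface implementation; `Matrix`, `matmulTile`, `plusTransposeReference` and `plusTransposeFused` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical minimal row-major matrix, for illustration only.
struct Matrix {
  std::size_t rows, cols;
  std::vector<float> data;
  Matrix(std::size_t r, std::size_t c, float init = 0.0f)
      : rows(r), cols(c), data(r * c, init) {}
  float &at(std::size_t i, std::size_t j) { return data[i * cols + j]; }
  float at(std::size_t i, std::size_t j) const { return data[i * cols + j]; }
};

// One tile of C = D + A * B (linalg.matmul accumulates into its outs):
// rows [r0, r0 + nr), columns [c0, c0 + nc), returned row-major nr x nc.
std::vector<float> matmulTile(const Matrix &a, const Matrix &b,
                              const Matrix &d, std::size_t r0, std::size_t nr,
                              std::size_t c0, std::size_t nc) {
  std::vector<float> tile(nr * nc);
  for (std::size_t i = 0; i < nr; ++i)
    for (std::size_t j = 0; j < nc; ++j) {
      float acc = d.at(r0 + i, c0 + j);
      for (std::size_t k = 0; k < a.cols; ++k)
        acc += a.at(r0 + i, k) * b.at(k, c0 + j);
      tile[i * nc + j] = acc;
    }
  return tile;
}

// Unfused reference: materialize C, then out(i, j) = C(i, j) + C(j, i).
Matrix plusTransposeReference(const Matrix &a, const Matrix &b,
                              const Matrix &d) {
  const std::size_t n = d.rows;
  assert(d.cols == n && a.rows == n && b.cols == n && a.cols == b.rows);
  const std::vector<float> c = matmulTile(a, b, d, 0, n, 0, n);
  Matrix out(n, n);
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = 0; j < n; ++j) out.at(i, j) = c[i * n + j] + c[j * n + i];
  return out;
}

// Tiled and fused: the producer tile is recomputed for the slice [i0, j0]
// and again for the transposed slice [j0, i0] of every output tile.
Matrix plusTransposeFused(const Matrix &a, const Matrix &b, const Matrix &d,
                          std::size_t ti, std::size_t tj) {
  const std::size_t n = d.rows;
  assert(d.cols == n && a.rows == n && b.cols == n && ti > 0 && tj > 0);
  Matrix out(n, n);
  for (std::size_t i0 = 0; i0 < n; i0 += ti) {
    const std::size_t ni = std::min(ti, n - i0);
    for (std::size_t j0 = 0; j0 < n; j0 += tj) {
      const std::size_t nj = std::min(tj, n - j0);
      const std::vector<float> lhs = matmulTile(a, b, d, i0, ni, j0, nj);
      const std::vector<float> rhs = matmulTile(a, b, d, j0, nj, i0, ni);
      for (std::size_t i = 0; i < ni; ++i)
        for (std::size_t j = 0; j < nj; ++j)
          out.at(i0 + i, j0 + j) = lhs[i * nj + j] + rhs[j * ni + i];
    }
  }
  return out;
}
```

The sketch makes the cost visible: every element of `A * B` is evaluated twice in total, once as part of an `lhs` tile and once as part of an `rhs` tile, which corresponds to the two `linalg.matmul` tiles per iteration in the CHECK lines.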
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index f8fa2459e667..78e26de1d54f 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -88,12 +88,8 @@ void registerTestInterfaces();
void registerTestLastModifiedPass();
void registerTestLinalgDecomposeOps();
void registerTestLinalgElementwiseFusion();
-void registerTestLinalgFusionTransforms();
-void registerTestLinalgTensorFusionTransforms();
-void registerTestLinalgTiledLoopFusionTransforms();
void registerTestLinalgGreedyFusion();
void registerTestLinalgHoisting();
-void registerTestLinalgTileAndFuseSequencePass();
void registerTestLinalgTransforms();
void registerTestLivenessPass();
void registerTestLoopFusion();
@@ -187,12 +183,8 @@ void registerTestPasses() {
mlir::test::registerTestLastModifiedPass();
mlir::test::registerTestLinalgDecomposeOps();
mlir::test::registerTestLinalgElementwiseFusion();
- mlir::test::registerTestLinalgFusionTransforms();
- mlir::test::registerTestLinalgTensorFusionTransforms();
- mlir::test::registerTestLinalgTiledLoopFusionTransforms();
mlir::test::registerTestLinalgGreedyFusion();
mlir::test::registerTestLinalgHoisting();
- mlir::test::registerTestLinalgTileAndFuseSequencePass();
mlir::test::registerTestLinalgTransforms();
mlir::test::registerTestLivenessPass();
mlir::test::registerTestLoopFusion();