[Mlir-commits] [mlir] 0c090dc - [mlir][Linalg] Deprecate legacy reshape + generic op folding patterns.
Mahesh Ravishankar
llvmlistbot at llvm.org
Thu Apr 21 15:25:36 PDT 2022
Author: Mahesh Ravishankar
Date: 2022-04-21T22:25:23Z
New Revision: 0c090dcc8a97a07bb3b3d2f64dbd1abf3990c1c6
URL: https://github.com/llvm/llvm-project/commit/0c090dcc8a97a07bb3b3d2f64dbd1abf3990c1c6
DIFF: https://github.com/llvm/llvm-project/commit/0c090dcc8a97a07bb3b3d2f64dbd1abf3990c1c6.diff
LOG: [mlir][Linalg] Deprecate legacy reshape + generic op folding patterns.
These patterns have been superseded by the fusion-by-collapsing patterns.
Differential Revision: https://reviews.llvm.org/D124145
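[Editor's note] For downstream users of the removed entry points, a minimal migration sketch follows. This is illustrative only: the surrounding pass boilerplate is assumed rather than taken from this patch; only the `populateFoldReshapeOpsByCollapsingPatterns` entry point, the control-function signature, and the greedy-driver call appear in the diff below.

    // Hedged migration sketch (assumed context: inside a pass's
    // runOnOperation(), with `op` being the operation the pass runs on).
    RewritePatternSet patterns(op->getContext());
    // Removed by this commit:
    //   linalg::populateFoldReshapeOpsByLinearizationPatterns(patterns);
    //   linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
    //   linalg::populatePushReshapeOpsPatterns(patterns);
    // Replacement: fusion by collapsing. The control function is invoked per
    // candidate producer/consumer pair; return true to allow the fusion.
    linalg::populateFoldReshapeOpsByCollapsingPatterns(
        patterns, [](const OpResult & /*producer*/, OpOperand & /*consumer*/) {
          return true;
        });
    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));

Clients that previously relied on the unit-dimension filtering (the removed skipUnitDimReshape helper, or allow-folding-unit-dim-reshapes=false) can presumably recover that behavior by returning false from the control function for reshapes that only collapse or expand unit dimensions.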
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
mlir/test/Dialect/Linalg/reshape_fusion.mlir
mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
Removed:
mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
mlir/test/Dialect/Linalg/reshape_linearization_fusion_with_unit_dims.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 9f717d07d276e..06f0e217986d7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -90,31 +90,11 @@ def LinalgFoldUnitExtentDims : Pass<"linalg-fold-unit-extent-dims", ""> {
def LinalgElementwiseOpFusion : Pass<"linalg-fuse-elementwise-ops"> {
let summary = "Fuse elementwise operations on tensors";
let constructor = "mlir::createLinalgElementwiseOpFusionPass()";
- let options = [
- Option<"allowFoldingUnitDimReshapes", "allow-folding-unit-dim-reshapes",
- "bool", /*default=*/"false",
- "Allow fusing linalg.tensor_reshape ops that performs unit "
- "dimension collapsing">
- ];
let dependentDialects = [
"AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"
];
}
-def LinalgFoldReshapeOpsByLinearization :
- Pass<"linalg-fold-reshape-ops-by-linearization"> {
- let summary = "Fold TensorReshapeOps with generic/indexed generic ops by "
- "linearization";
- let constructor = "mlir::createFoldReshapeOpsByLinearizationPass()";
- let options = [
- Option<"allowFoldingUnitDimReshapes", "allow-folding-unit-dim-reshapes",
- "bool", /*default=*/"false",
- "Allow fusing linalg.tensor_reshape ops that performs unit "
- "dimension collapsing">
- ];
- let dependentDialects = ["AffineDialect", "memref::MemRefDialect"];
-}
-
def LinalgNamedOpConversion: Pass<"linalg-named-op-conversion"> {
let summary = "Convert from one named linalg op to another.";
let constructor = "mlir::createLinalgNamedOpConversionPass()";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index b4fefc21132e7..188f2b436a3d3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -37,10 +37,6 @@ struct LinalgElementwiseFusionOptions;
struct LinalgFusionOptions;
struct LinalgTilingOptions;
-/// Default function to control reshape folding. Skips folding unit dimension
-/// reshapes.
-bool skipUnitDimReshape(const OpResult &producer, OpOperand &consumer);
-
//===----------------------------------------------------------------------===//
// Transformations exposed as function calls.
//===----------------------------------------------------------------------===//
@@ -91,24 +87,6 @@ void populateFoldReshapeOpsByCollapsingPatterns(
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
const ControlFusionFn &controlFn);
-/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
-/// producer (consumer) generic operation by linearizing the indexing map used
-/// to access the source (target) of the reshape operation in the generic
-/// operation.
-/// TODO(ravishankarm): These patterns are to be deprecated in favor of using
-/// the `populateFoldReshapeByCollapsingPatterns`.
-void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
-
-/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
-/// producer (consumer) generic operation by linearizing the indexing map used
-/// to access the source (target) of the reshape operation in the generic
-/// operation. The patterns are applied only when the tensor reshape involved is
-/// collapsing (introducing) unit-extent dimensions.
-/// TODO(ravishankarm): These patterns are to be deprecated in favor of using
-/// the `populateFoldReshapeByCollapsingPatterns`.
-void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
- RewritePatternSet &patterns);
-
/// Pattern to fuse a `tensor.pad` operation with the producer of its source,
/// if the producer is a `linalg` operation with all parallel iterator types.
void populateFuseTensorPadWithProducerLinalgOpPatterns(
@@ -128,12 +106,6 @@ void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);
/// Patterns that are used to bubble up extract slice op above linalg op.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
-/// Patterns to push reshape op towards the end of the graph in order to expose
-/// more fusion opportunities.
-/// TODO(ravishankarm): These patterns are to be deprecated in favor of using
-/// the `populateFoldReshapeByCollapsingPatterns`.
-void populatePushReshapeOpsPatterns(RewritePatternSet &patterns);
-
/// Perform standalone tiling of a single LinalgOp by `tileSizes`.
/// and permute the loop nest according to `interchangeVector`
/// The permutation is expressed as a list of integers that specify
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 3aabac2ba456a..cc0ec0866f842 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -392,263 +392,6 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
};
} // namespace
-//===---------------------------------------------------------------------===//
-// Methods and patterns that fuse reshape ops with elementwise operations by
-// linearization of indexing maps.
-//===---------------------------------------------------------------------===//
-
-// TODO(ravishankarm): The indexing maps
-// these produce in the general case are detrimental to transformations.
-// These patterns are on deprecation path in favor of using fusion by
-// collapsing, which covers the only legitimate use case of this pattern of
-// folding unit-extent dims.
-
-/// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
-/// provided, given the shape of the source tensor that corresponds to the
-/// `sourceMap`. Note that this implicitly assumes that the tensors dimensions
-/// are "row-major" ordered logically.
-///
-/// For example:
-///
-/// %0 = op ... : tensor<?x?x4x5xf32>
-/// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
-///
-/// and reshape:
-/// %1 = tensor.collapse_shape %0 [[0], [0, 1, 2]] :
-/// tensor<?x?x4x5xf32> into tensor<?x?xf32>
-///
-/// would be rewritten into:
-/// %0 = op ... : tensor<?x?x4x5xf32>
-/// with output index_map
-/// `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
-template <typename TensorReshapeOp>
-static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
- TensorReshapeOp reshapeOp) {
- constexpr bool isExpanding =
- std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
- ArrayRef<int64_t> sourceShape =
- (isExpanding ? reshapeOp.getResultType().getShape()
- : reshapeOp.getSrcType().getShape());
- SmallVector<AffineExpr> resultExprs;
- ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
- MLIRContext *context = sourceMap.getContext();
-
- // Compute the result exprs based on the reassociation maps.
- for (auto &indices : reshapeOp.getReassociationIndices()) {
- // Assume that they are in-order and contiguous (already checked in
- // verifier).
- assert(!indices.empty());
- SmallVector<int64_t> sizes;
- SmallVector<AffineExpr> dimExprs;
- for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()),
- sourceExprs.slice(indices[0], indices.size()))) {
- if (std::get<0>(en) == 1)
- continue;
- sizes.push_back(std::get<0>(en));
- dimExprs.push_back(std::get<1>(en));
- }
- AffineExpr linearizedExpr =
- makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
- resultExprs.push_back(linearizedExpr);
- }
- // The new affine map cannot drop unused dimension but some new symbols may
- // have been added. Create a map with at least as many dimensions/symbols as
- // the original affine map.
- int64_t maxDim = -1;
- int64_t maxSym = -1;
- getMaxDimAndSymbol<SmallVector<AffineExpr>>({resultExprs}, maxDim, maxSym);
- unsigned numDims = std::max(unsigned(maxDim + 1), sourceMap.getNumDims());
- unsigned numSyms = std::max(unsigned(maxSym + 1), sourceMap.getNumSymbols());
- return AffineMap::get(numDims, numSyms, resultExprs, context);
-}
-
-// tensor::ExpandShapeOp is fusable with its consumer (i.e. reshape as a
-// producer). Fusing when operand has higher rank will require use of mods and
-// divs in the indexing maps of the fused op which would make it non-invertible.
-static bool isTensorReshapeOpFoldableByLinearization(
- tensor::ExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) {
- if (!asProducer)
- return false;
- return useIndexMap.isPermutation();
-}
-
-// tensor::CollapseShapeOp is fusable with its producer (i.e. reshape as a
-// consumer).
-static bool
-isTensorReshapeOpFoldableByLinearization(tensor::CollapseShapeOp collapseOp,
- AffineMap useIndexMap,
- bool asProducer) {
- if (asProducer)
- return false;
- return useIndexMap.isPermutation();
-}
-
-/// Check if the reshape operation is only expansion into/collapsing of
-/// unit-dimension.
-template <typename TensorReshapeOp>
-static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) {
- constexpr bool isExpanding =
- std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
- ArrayRef<int64_t> expandedShape =
- (isExpanding ? reshapeOp.getResultType().getShape()
- : reshapeOp.getSrcType().getShape());
- for (auto &indices : reshapeOp.getReassociationIndices()) {
- unsigned numUnitDims = 0;
- for (int64_t position : indices)
- if (expandedShape[position] == 1)
- numUnitDims++;
- if (numUnitDims != indices.size() - 1)
- return false;
- }
- return true;
-}
-
-namespace {
-/// Pattern to fold tensor_expand_shape op with its consumer by using the source
-/// of the reshape op as the operand in the consumer (instead of the result of
-/// the tensor_collapse_shape). The corresponding index map in the consumer
-/// needs to be modified to linearize the folded dimension.
-///
-/// For example,
-///
-/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-/// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
-/// tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
-/// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
-/// ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
-/// -> tensor<?x?x4x?xf32>
-///
-/// can be folded into
-///
-/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
-/// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
-/// ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
-/// -> tensor<?x?x4x?xf32>
-template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
-struct FoldProducerReshapeOpByLinearization
- : public OpRewritePattern<GenericOp> {
- using OpRewritePattern<GenericOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(GenericOp genericOp,
- PatternRewriter &rewriter) const override {
- if (!genericOp.hasTensorSemantics())
- return failure();
- SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
- for (const auto &en : llvm::enumerate(inputOperands)) {
- auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
- if (!reshapeOp)
- continue;
-
- if (!isTensorReshapeOpFoldableByLinearization(
- reshapeOp, genericOp.getTiedIndexingMap(en.value()),
- /*asProducer =*/true) ||
- (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
- continue;
-
- // Compute the fused operands list,
- SmallVector<Value> fusedOperands = genericOp.getInputOperands();
- fusedOperands[en.index()] = reshapeOp.src();
- SmallVector<Value> outputOperands = genericOp.getOutputOperands();
- llvm::append_range(fusedOperands, outputOperands);
-
- // Compute indexing_maps for the fused operation. The indexing_maps for
- // the operands of the consumers that arent fused are the same.
- SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();
-
- // Compute the indexing map to use for the result of the producer.
- AffineMap modifiedMap =
- linearizeCollapsedDims(fusedIndexMaps[en.index()], reshapeOp);
- // The modified map cannot have symbols.
- if (modifiedMap.getNumSymbols())
- return failure();
- for (AffineExpr expr : modifiedMap.getResults()) {
- if (!expr.isPureAffine())
- return failure();
- }
- fusedIndexMaps[en.index()] = modifiedMap;
-
- // Further check that the resulting index maps can be fused and
- // inverted. Without this the resultant op is not legal.
- if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
- return rewriter.notifyMatchFailure(
- genericOp, "fused op loop bound computation failed");
- }
-
- rewriter.startRootUpdate(genericOp);
- genericOp->setOperands(fusedOperands);
- genericOp.indexing_mapsAttr(
- rewriter.getAffineMapArrayAttr(fusedIndexMaps));
- rewriter.finalizeRootUpdate(genericOp);
- return success();
- }
- return failure();
- }
-};
-
-/// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its
-/// producer. The corresponding index map in the consumer needs to be modified
-/// to linearize the folded dimension.
-template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
-struct FoldConsumerReshapeOpByLinearization
- : public OpRewritePattern<TensorReshapeOp> {
- using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
- PatternRewriter &rewriter) const override {
- GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
- if (!producer || !producer.hasTensorSemantics() ||
- producer.getNumOutputs() != 1 ||
- !isTensorReshapeOpFoldableByLinearization(
- reshapeOp,
- producer.getTiedIndexingMap(producer.getOutputOperand(0)),
- /*asProducer =*/false) ||
- (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
- return failure();
- // The indexing_maps for the operands of the fused operation are same as
- // those for the operands of the producer.
- SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();
-
- // Compute the indexing map to use for the operand of the producer.
- AffineMap modifiedMap = linearizeCollapsedDims(
- producer.getTiedIndexingMap(producer.getOutputOperand(0)), reshapeOp);
- for (AffineExpr expr : modifiedMap.getResults()) {
- if (!expr.isPureAffine()) {
- return rewriter.notifyMatchFailure(
- producer, "fused op indexing map is not affine");
- }
- }
- fusedIndexMaps.back() = modifiedMap;
-
- // Further check that the resulting index maps can be fused and
- // inverted. Without this the resultant op is not legal.
- if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
- return rewriter.notifyMatchFailure(
- producer, "fused op loop bound computation failed");
- }
-
- Location loc = producer.getLoc();
- SmallVector<Value> inputOperands = producer.getInputOperands();
- Value output = rewriter.create<TensorReshapeOp>(
- loc, producer.getOutputOperand(0)->get(),
- reshapeOp.getReassociationExprs());
- auto fusedOp = rewriter.create<GenericOp>(
- loc, reshapeOp.getResultType(),
- /*inputs=*/inputOperands,
- // TODO: handle outputs.
- /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
- producer.iterator_types(),
- /*doc=*/nullptr,
- /*library_call=*/nullptr);
- auto &fusedRegion = fusedOp->getRegion(0);
- rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
- fusedRegion.begin());
- rewriter.replaceOp(reshapeOp, fusedOp->getResults());
- return success();
- }
-};
-} // namespace
-
//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
@@ -1737,174 +1480,6 @@ class FoldWithProducerReshapeOpByCollapsing
};
} // namespace
-//===---------------------------------------------------------------------===//
-// Methods and patterns to convert tensor.expand_shape -> linalg.generic
-// into linalg.generic -> tensor.expand_shape, i.e. push the reshape down.
-//===---------------------------------------------------------------------===//
-
-// TODO(ravishankarm): This pattern is to be deprecated in favor of fusion by
-// collapsing that provides a more general functionality. This pattern is very
-// specific to a particular use case. The fusion by collapsing can provide the
-// same control to clients using the control function there.
-
-static SmallVector<ReassociationIndices>
-getReassociationIndices(ArrayRef<AffineMap> maps) {
- SmallVector<ReassociationIndices> reassociation;
- for (AffineMap map : maps) {
- ReassociationIndices indices;
- for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
- unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
- indices.push_back(pos);
- }
- reassociation.push_back(indices);
- }
- return reassociation;
-}
-
-namespace {
-/// Pattern to move rank reducing reshape after an elementwise linalg generic
-/// op. This is useful to expose more fusion opportunities between named ops and
-/// generic ops. This can only be done if there is no broadcast or permuation
-/// within the dimensions we need to merge.
-///
-/// For example,
-///
-/// %0 = tensor.expand_shape %A [[0, 1], [2]]
-/// : tensor<12544x16xf32> into tensor<112x112x16xf32>
-/// %2 = linalg.generic {indexing_maps = [
-/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
-/// affine_map<(d0, d1, d2) -> (d2)>,
-/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
-/// ["parallel", "parallel", "parallel"]} {
-/// } -> tensor<112x112x16xf32>
-///
-/// into
-///
-/// %2 = linalg.generic {indexing_maps = [
-/// affine_map<(d0, d1) -> (d0, d1)>,
-/// affine_map<(d0, d1) -> (d1)>,
-/// affine_map<(d0, d1) -> (d0, d1)>],
-/// iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
-/// : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
-/// } -> tensor<12544x16xf32>
-/// %3 = tensor.expand_shape %2 [[0, 1], [2]]
-/// : tensor<12544x16xf32> into tensor<112x112x16xf32>
-struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
- using OpRewritePattern<GenericOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(GenericOp genericOp,
- PatternRewriter &rewriter) const override {
- // Only apply to elementwise linalg on tensor.
- if (!genericOp.hasTensorSemantics() || genericOp.hasIndexSemantics() ||
- genericOp.getNumParallelLoops() != genericOp.getNumLoops())
- return failure();
- // Only support identity output maps. It could be extended to permuations if
- // needed.
- if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) {
- return !genericOp.getTiedIndexingMap(opOperand).isIdentity();
- }))
- return failure();
- int64_t destRank = genericOp.getNumParallelLoops();
- SmallVector<Value> newOperands = genericOp.getInputOperands();
- tensor::ExpandShapeOp reshapeFound;
- // 1. Look for tensor_expand_shape operands and figure out save the
- // dimensions merged.
- SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
- for (const auto &en : llvm::enumerate(inputOperands)) {
- auto reshapeOp =
- en.value()->get().template getDefiningOp<tensor::ExpandShapeOp>();
- if (!reshapeOp)
- continue;
- // TODO: We could support non-identity map as long as the merged
- // dimensions are still contiguous.
- if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
- continue;
- if (reshapeFound) {
- // Only support a second reshape op if it has the same reassociate maps.
- if (reshapeFound.getReassociationMaps() ==
- reshapeOp.getReassociationMaps())
- newOperands[en.index()] = reshapeOp.src();
- continue;
- }
- reshapeFound = reshapeOp;
- newOperands[en.index()] = reshapeOp.src();
- }
- if (!reshapeFound)
- return failure();
-
- // Calculate the reassociation indices and rassociated reverse map.
- SmallVector<ReassociationIndices> reassociation =
- getReassociationIndices(reshapeFound.getReassociationMaps());
- SmallVector<unsigned> remap(destRank);
- for (auto &indices : llvm::enumerate(reassociation)) {
- for (int64_t index : indices.value()) {
- remap[index] = indices.index();
- }
- }
- // 2. Verify that we can merge the dimensions in the linalg and that we
- // don't need to create new reshapes operands. Inserting new reshape
- // operands would defeat the purpose of the transformation.
- for (const auto &en : llvm::enumerate(inputOperands)) {
- if (en.value()->get() == newOperands[en.index()]) {
- AffineMap map = genericOp.getTiedIndexingMap(en.value());
- for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
- if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
- return failure();
- }
- }
- }
-
- // 3. Calculate the affine map remapping and the reassociation to apply to
- // output tensors.
- SmallVector<AffineMap> newMaps;
- unsigned newRank = reassociation.size();
- for (auto map : genericOp.getIndexingMaps()) {
- SmallVector<AffineExpr> newExprs;
- for (auto expr : map.getResults()) {
- unsigned position = expr.template cast<AffineDimExpr>().getPosition();
- // Skip dimension merged except for the last of the group.
- if (reassociation[remap[position]].back() == position) {
- newExprs.push_back(
- getAffineDimExpr(remap[position], genericOp.getContext()));
- }
- }
- newMaps.push_back(
- AffineMap::get(newRank, 0, newExprs, genericOp.getContext()));
- }
-
- // 4. Reshape the output tensors.
- SmallVector<Value> newOutputs;
- SmallVector<Type> newOutputTypes;
- for (auto output : genericOp.outputs()) {
- auto newOutputType = RankedTensorType::get(
- reshapeFound.getSrcType().getShape(),
- output.getType().template cast<RankedTensorType>().getElementType());
- Value newOutput = rewriter.create<tensor::CollapseShapeOp>(
- genericOp->getLoc(), newOutputType, output, reassociation);
- newOutputTypes.push_back(newOutputType);
- newOutputs.push_back(newOutput);
- }
- // 5. Create a new generic op with lowerer rank.
- SmallVector<StringRef> iteratorTypes(newRank,
- getParallelIteratorTypeName());
- auto newOp = rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes,
- newOperands, newOutputs, newMaps,
- iteratorTypes);
- rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
- newOp.region().begin());
- // 6. Reshape the so that the type matches the uses.
- SmallVector<Value> newResults;
- for (const auto &result : llvm::enumerate(newOp->getResults())) {
- newResults.push_back(rewriter.create<tensor::ExpandShapeOp>(
- genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()],
- result.value(), reassociation));
- }
- rewriter.replaceOp(genericOp, newResults);
- return success();
- }
-};
-} // namespace
-
//===---------------------------------------------------------------------===//
// Methods and patterns that fuse constants with linalg.generic operations.
//===---------------------------------------------------------------------===//
@@ -2093,27 +1668,6 @@ struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
}
};
} // namespace
-//===---------------------------------------------------------------------===//
-// Methods that add patterns described in this file to a pattern list.
-//===---------------------------------------------------------------------===//
-
-void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
- RewritePatternSet &patterns) {
- patterns.add<
- FoldProducerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
- FoldProducerReshapeOpByLinearization<false, tensor::ExpandShapeOp>,
- FoldConsumerReshapeOpByLinearization<false, tensor::CollapseShapeOp>>(
- patterns.getContext());
-}
-
-void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
- RewritePatternSet &patterns) {
- patterns
- .add<FoldProducerReshapeOpByLinearization<true, tensor::CollapseShapeOp>,
- FoldProducerReshapeOpByLinearization<true, tensor::ExpandShapeOp>,
- FoldConsumerReshapeOpByLinearization<true, tensor::CollapseShapeOp>>(
- patterns.getContext());
-}
void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
RewritePatternSet &patterns,
@@ -2140,28 +1694,10 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
RemoveOutsDependency>(context);
}
-void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
- auto *context = patterns.getContext();
- patterns.add<PushExpandingReshape>(context);
-}
-
//===---------------------------------------------------------------------===//
// Passes
//===---------------------------------------------------------------------===//
-bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
- OpOperand &consumer) {
- if (auto producerCollapseOp =
- dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
- return !isUnitDimExpansionOnly(producerCollapseOp);
- }
- if (auto consumerExpandOp =
- dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
- return !isUnitDimExpansionOnly(consumerExpandOp);
- }
- return true;
-}
-
namespace {
/// Pass that fuses generic ops on tensors. Used only for testing.
@@ -2186,9 +1722,7 @@ struct LinalgElementwiseOpFusionPass
// Add elementwise op fusion patterns.
populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
- populateFoldReshapeOpsByExpansionPatterns(
- patterns,
- allowFoldingUnitDimReshapes ? defaultControlFn : skipUnitDimReshape);
+ populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
// Add the sparse tensor rewriting patterns.
populateSparseTensorRewriting(patterns);
@@ -2212,27 +1746,8 @@ struct LinalgElementwiseOpFusionPass
}
};
-/// Pass to test folding of reshape ops with generic ops by linearization.
-struct FoldReshapeOpsByLinearizationPass
- : public LinalgFoldReshapeOpsByLinearizationBase<
- FoldReshapeOpsByLinearizationPass> {
- void runOnOperation() override {
- Operation *op = getOperation();
- RewritePatternSet patterns(op->getContext());
- populateFoldReshapeOpsByLinearizationPatterns(patterns);
- if (allowFoldingUnitDimReshapes) {
- populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
- }
- (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
- }
-};
-
} // namespace
std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
return std::make_unique<LinalgElementwiseOpFusionPass>();
}
-
-std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() {
- return std::make_unique<FoldReshapeOpsByLinearizationPass>();
-}
diff --git a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
index ea699d820b610..33489cba431fc 100644
--- a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=push-expanding-reshape -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-with-reshape-by-collapsing -split-input-file | FileCheck %s
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
@@ -124,30 +124,3 @@ func.func @type_correctness(%arg0 : tensor<6x5xi32>, %arg1 : tensor<5xf32>,
// CHECK-SAME: outs(%{{.+}} : tensor<6x5xf32>)
// CHECK: tensor.expand_shape %[[OP]]
// CHECK-SAME: tensor<6x5xf32> into tensor<2x3x5xf32>
-
-// -----
-
-func.func @generic_op_index_semantics(%A: tensor<?x16xi64>, %B: tensor<16xi64>, %init: tensor<?x112x16xi64>) -> tensor<?x112x16xi64> {
- %0 = tensor.expand_shape %A [[0, 1], [2]]
- : tensor<?x16xi64> into tensor<?x112x16xi64>
- %2 = linalg.generic {indexing_maps = [
- affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>,
- affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
- iterator_types = ["parallel", "parallel", "parallel"]}
- ins(%0, %B : tensor<?x112x16xi64>, tensor<16xi64>)
- outs(%init : tensor<?x112x16xi64>) {
- ^bb0(%arg1: i64, %arg2: i64, %arg3: i64): // no predecessors
- %index = linalg.index 0 : index
- %1 = arith.index_cast %index : index to i64
- %add = arith.addi %arg1, %1 : i64
- %s = arith.subi %add, %arg2 : i64
- linalg.yield %s : i64
- } -> tensor<?x112x16xi64>
- return %2 : tensor<?x112x16xi64>
-}
-// CHECK: func @generic_op_index_semantics
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x16xi64>
-// CHECK: %[[RESHAPE:.+]] = tensor.expand_shape %[[ARG0]]
-// CHECK: %[[RESULT:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[RESHAPE]]
-// CHECK: return %[[RESULT]]
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index ebee7e75ac5a1..45e8721278e1c 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -linalg-fuse-elementwise-ops="allow-folding-unit-dim-reshapes=false" -split-input-file | FileCheck %s
-// RUN: mlir-opt %s -linalg-fuse-elementwise-ops="allow-folding-unit-dim-reshapes=true" -split-input-file | FileCheck %s --check-prefix=FOLDUNITDIM
+// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-with-reshape-by-expansion -split-input-file | FileCheck %s
+
#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
#map2 = affine_map<(d0, d1, d2) -> ()>
@@ -14,7 +14,7 @@ func.func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
indexing_maps = [#map0, #map1, #map2, #map1],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%0, %arg1, %arg2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32)
- outs(%0 : tensor<?x?x?xf32>) {
+ outs(%arg1 : tensor<?x?x?xf32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32, %s: f32):
%1 = arith.mulf %arg3, %arg4 : f32
%2 = arith.addf %1, %arg5 : f32
@@ -30,15 +30,15 @@ func.func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: f32
-// CHECK: %[[T0:.+]] = tensor.collapse_shape %[[ARG0]]
-// CHECK-SAME: [0], [1, 2], [3]
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
// CHECK-SAME: [0], [1], [2, 3]
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]]
+// CHECK-SAME: [0], [1], [2, 3]
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]], #[[MAP6]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[ARG0]], %[[T1]], %[[ARG2]] : tensor<?x?x4x?xf32>, tensor<?x?x?x4xf32>, f32)
-// CHECK-SAME: outs(%{{.+}} : tensor<?x?x?x4xf32>)
+// CHECK-SAME: outs(%[[T2]] : tensor<?x?x?x4xf32>)
// CHECK: %[[T4:.+]] = tensor.collapse_shape %[[T3]]
// CHECK-SAME: [0], [1], [2, 3]
// CHECK-SAME: tensor<?x?x?x4xf32> into tensor<?x?x?xf32>
@@ -80,12 +80,14 @@ func.func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
// CHECK-SAME: tensor<?x?xf32> into tensor<?x4x?x5xf32>
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
// CHECK-SAME: [0], [1, 2, 3]
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-SAME: [0], [1, 2, 3]
// CHECK-SAME: tensor<?x?xf32> into tensor<?x4x?x5xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]], #[[MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[T0]], %[[T1]], %[[ARG2]] : tensor<?x4x?x5xf32>, tensor<?x4x?x5xf32>, f32)
-// CHECK-SAME: outs(%{{.+}} : tensor<?x4x?x5xf32>)
+// CHECK-SAME: outs(%[[T2]] : tensor<?x4x?x5xf32>)
// CHECK: return %[[T3]] : tensor<?x4x?x5xf32>
@@ -121,11 +123,14 @@ func.func @reshape_as_consumer_permutation
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
// CHECK-SAME: [0, 1, 2], [3]
// CHECK-SAME: tensor<?x?xf32> into tensor<3x4x?x?xf32>
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-SAME: [0, 1], [2], [3, 4, 5]]
+// CHECK-SAME: tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP8]], #[[MAP9]], #[[MAP10]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<3x4x?x?x2x?xf32>, tensor<3x4x?x?xf32>)
-// CHECK-SAME: outs(%{{.+}} : tensor<?x2x?x3x4x?xf32>)
+// CHECK-SAME: outs(%[[T2]] : tensor<?x2x?x3x4x?xf32>)
// CHECK: return %[[T3]] : tensor<?x2x?x3x4x?xf32>
// -----
@@ -155,14 +160,19 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: func @generic_op_reshape_consumer_static
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<264x4xf32>
+// CHECK-DAG: %[[CST:.+]] = arith.constant
+// CHECK-SAME: : tensor<8x33x4xf32>
+// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [264, 4]
// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME: [0, 1], [2]
// CHECK-SAME: tensor<264x4xf32> into tensor<8x33x4xf32>
-// CHECK: %[[T1:.+]] = linalg.init_tensor [8, 33, 4]
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[INIT]]
+// CHECK-SAME: [0, 1], [2]
+// CHECK-SAME: : tensor<264x4xf32> into tensor<8x33x4xf32>
// CHECK: %[[T2:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]]]
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "parallel"]
-// CHECK-SAME: ins(%[[T0]] : tensor<8x33x4xf32>)
+// CHECK-SAME: ins(%[[T0]], %[[CST]] :
// CHECK-SAME: outs(%[[T1]] : tensor<8x33x4xf32>)
// CHECK: return %[[T2]] : tensor<8x33x4xf32>
@@ -246,7 +256,8 @@ func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
}
// Only check the body in the indexed version of the test.
-// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 5 + d2 * 20)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 4)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 5)>
// CHECK: func @indexed_producer_reshape_consumer_fusion
// CHECK: linalg.generic
// CHECK: ^{{.*}}(
@@ -256,11 +267,12 @@ func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index
// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
-// CHECK-DAG: %[[T3:.+]] = affine.apply #[[MAP]](%[[IDX3]], %[[IDX2]], %[[IDX1]])
+// CHECK: %[[T1:.+]] = affine.apply #[[MAP1]](%[[IDX2]], %[[IDX1]])
+// CHECK: %[[T2:.+]] = affine.apply #[[MAP2]](%[[IDX3]], %[[T1]])
// CHECK: %[[T4:.+]] = arith.muli %[[ARG3]], %[[ARG4]]
// CHECK: %[[T5:.+]] = arith.index_cast %[[IDX0]]
// CHECK: %[[T6:.+]] = arith.addi %[[T4]], %[[T5]]
-// CHECK: %[[T7:.+]] = arith.index_cast %[[T3]]
+// CHECK: %[[T7:.+]] = arith.index_cast %[[T2]]
// CHECK: %[[T8:.+]] = arith.addi %[[T6]], %[[T7]]
// CHECK: linalg.yield %[[T8]]
@@ -295,24 +307,29 @@ func.func @reshape_as_consumer_permutation
return %d : tensor<2x3x4x5x6x7xi32>
}
+// -----
-// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
-// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
-// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
-// CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
-// CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 7 + d2 * 42)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 6)>
+// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 7)>
// CHECK: func @reshape_as_consumer_permutation
// CHECK-SAME: %[[ARG0:.+]]: tensor<210x6x4xi32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<210x4xi32>
+// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [6, 4, 210]
// CHECK-DAG: %[[T1:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME: [0, 1, 2], [3, 4], [5]
// CHECK-DAG: %[[T2:.+]] = tensor.expand_shape %[[ARG1]]
// CHECK-SAME: [0, 1, 2], [3]
-// CHECK-DAG: %[[T0:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7]
+// CHECK-DAG: %[[T3:.+]] = tensor.expand_shape %[[INIT]]
+// CHECK-SAME: [0, 1], [2], [3, 4, 5]
+// CHECK-SAME: : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
// CHECK: %[[T4:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]]]
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: ins(%[[T1]], %[[T2]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>)
-// CHECK-SAME: outs(%[[T0]] : tensor<2x3x4x5x6x7xi32>)
+// CHECK-SAME: outs(%[[T3]] : tensor<2x3x4x5x6x7xi32>)
// CHECK: ^{{.+}}(
// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: i32, %[[ARG9:[a-zA-Z0-9]+]]: i32,
// CHECK-SAME: %[[ARG10:[a-zA-Z0-9]+]]: i32)
@@ -322,15 +339,16 @@ func.func @reshape_as_consumer_permutation
// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
// CHECK-DAG: %[[IDX4:.+]] = linalg.index 4 : index
// CHECK-DAG: %[[IDX5:.+]] = linalg.index 5 : index
-// CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP8]](%[[IDX1]], %[[IDX0]])
-// CHECK-DAG: %[[T6:.+]] = affine.apply #[[MAP9]](%[[IDX4]], %[[IDX3]], %[[IDX2]])
-// CHECK-DAG: %[[T7:.+]] = arith.addi %[[ARG8]], %[[ARG9]]
-// CHECK: %[[T8:.+]] = arith.index_cast %[[T5]]
-// CHECK: %[[T9:.+]] = arith.addi %[[T7]], %[[T8]]
-// CHECK: %[[T10:.+]] = arith.index_cast %[[T6]]
-// CHECK: %[[T11:.+]] = arith.addi %[[T9]], %[[T10]]
-// CHECK: %[[T12:.+]] = arith.index_cast %[[IDX5]]
-// CHECK: %[[T13:.+]] = arith.addi %[[T11]], %[[T12]]
+// CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP3]](%[[IDX1]], %[[IDX0]])
+// CHECK-DAG: %[[T6:.+]] = affine.apply #[[MAP4]](%[[IDX3]], %[[IDX2]])
+// CHECK-DAG: %[[T7:.+]] = affine.apply #[[MAP5]](%[[IDX4]], %[[T6]])
+// CHECK-DAG: %[[T8:.+]] = arith.addi %[[ARG8]], %[[ARG9]]
+// CHECK: %[[T9:.+]] = arith.index_cast %[[T5]]
+// CHECK: %[[T10:.+]] = arith.addi %[[T8]], %[[T9]]
+// CHECK: %[[T11:.+]] = arith.index_cast %[[T7]]
+// CHECK: %[[T12:.+]] = arith.addi %[[T10]], %[[T11]]
+// CHECK: %[[T13:.+]] = arith.index_cast %[[IDX5]]
+// CHECK: %[[T14:.+]] = arith.addi %[[T12]], %[[T13]]
// -----
@@ -421,94 +439,18 @@ func.func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
// CHECK-SAME: [0, 1, 2], [3]
// CHECK-SAME: tensor<?x?xf32> into tensor<?x4x5x?xf32>
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-SAME: [0], [1, 2, 3]
+// CHECK-SAME: tensor<?x?xf32> into tensor<?x?x4x5xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP4]], #[[MAP4]], #[[MAP5]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<?x4x5x?xf32>, tensor<?x4x5x?xf32>)
-// CHECK-SAME: outs(%{{.+}} : tensor<?x?x4x5xf32>)
+// CHECK-SAME: outs(%[[T2]] : tensor<?x?x4x5xf32>)
// CHECK: return %[[T3]] : tensor<?x?x4x5xf32>
// -----
-func.func @unit_dim_reshape_expansion(%arg0 : tensor<1x5xf32>) -> tensor<5x5xf32> {
- %0 = tensor.collapse_shape %arg0 [[0, 1]]
- : tensor<1x5xf32> into tensor<5xf32>
- %1 = linalg.init_tensor [5, 5] : tensor<5x5xf32>
- %2 = linalg.generic
- {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
- affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]}
- ins(%0 : tensor<5xf32>) outs(%1 : tensor<5x5xf32>) {
- ^bb0(%arg2: f32, %arg3: f32):
- linalg.yield %arg2 : f32
- } -> tensor<5x5xf32>
- return %2 : tensor<5x5xf32>
-}
-// CHECK: func @unit_dim_reshape_expansion
-// CHECK-DAG: tensor.collapse_shape
-// CHECK-DAG: linalg.init_tensor
-// CHECK: linalg.generic
-
-// -----
-
-func.func @unit_dim_reshape_collapse(%arg0 : tensor<5xf32>) -> tensor<5x1x5xf32> {
- %0 = linalg.init_tensor [5, 5] : tensor<5x5xf32>
- %1 = linalg.generic
- {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
- affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]}
- ins(%arg0 : tensor<5xf32>) outs(%0 : tensor<5x5xf32>) {
- ^bb0(%arg2: f32, %arg3: f32):
- linalg.yield %arg2 : f32
- } -> tensor<5x5xf32>
- %2 = tensor.expand_shape %1 [[0, 1], [2]]
- : tensor<5x5xf32> into tensor<5x1x5xf32>
- return %2 : tensor<5x1x5xf32>
-}
-// CHECK: func @unit_dim_reshape_collapse
-// CHECK: linalg.init_tensor
-// CHECK: linalg.generic
-// CHECK: tensor.expand_shape
-
-// -----
-
-func.func @unit_dim_reshape_expansion_full
- (%arg0 : tensor<1x?x1x2x1x4xf32>, %arg1 : tensor<?x2x4xf32>)
- -> tensor<?x2x4xf32> {
- %c1 = arith.constant 1 : index
- %0 = tensor.collapse_shape %arg0 [[0, 1, 2], [3, 4], [5]]
- : tensor<1x?x1x2x1x4xf32> into tensor<?x2x4xf32>
- %1 = tensor.dim %arg0, %c1 : tensor<1x?x1x2x1x4xf32>
- %2 = linalg.init_tensor [%1, 2, 4] : tensor<?x2x4xf32>
- %3 = linalg.generic
- {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
- affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
- affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
- iterator_types = ["parallel", "parallel", "parallel"]}
- ins(%0, %arg1 : tensor<?x2x4xf32>, tensor<?x2x4xf32>)
- outs(%2 : tensor<?x2x4xf32>) {
- ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
- %4 = arith.mulf %arg2, %arg3 : f32
- linalg.yield %4 : f32
- } -> tensor<?x2x4xf32>
- return %3 : tensor<?x2x4xf32>
-}
-// CHECK: func @unit_dim_reshape_expansion_full
-// CHECK-DAG: tensor.collapse_shape
-// CHECK-DAG: linalg.init_tensor
-// CHECK: linalg.generic
-// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<?x2x4xf32>, tensor<?x2x4xf32>)
-
-// FOLDUNITDIM: func @unit_dim_reshape_expansion_full
-// FOLDUNITDIM-SAME: %[[ARG0:.+]]: tensor<1x?x1x2x1x4xf32>
-// FOLDUNITDIM-SAME: %[[ARG1:.+]]: tensor<?x2x4xf32>
-// FOLDUNITDIM-DAG: %[[RESHAPE:.+]] = tensor.expand_shape %[[ARG1]]
-// FOLDUNITDIM: linalg.generic
-// FOLDUNITDIM-SAME: ins(%[[ARG0]], %[[RESHAPE]] : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>)
-// FOLDUNITDIM-SAME: outs(%{{.+}} : tensor<1x?x1x2x1x4xf32>)
-
-// -----
-
func.func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
%c0 = arith.constant 0 : index
%0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
@@ -554,7 +496,6 @@ func.func @no_fuse_mismatched_dynamism(%arg0: tensor<2x1xi64>, %arg1: tensor<?xi
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x1xi64>
// CHECK-SAME: %[[ARG1:.+]]: tensor<?xi64>
// CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
-// CHECK: %[[CAST:.+]] = tensor.cast %[[ARG1]] : tensor<?xi64> to tensor<2xi64>
// CHECK: %[[GENERIC:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[RESHAPE]], %[[CAST]] : tensor<2xi64>, tensor<2xi64>)
+// CHECK-SAME: ins(%[[RESHAPE]], %[[ARG1]] : tensor<2xi64>, tensor<?xi64>)
// CHECK: return %[[GENERIC]]
diff --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
deleted file mode 100644
index 089b30694231f..0000000000000
--- a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
+++ /dev/null
@@ -1,287 +0,0 @@
-// RUN: mlir-opt -split-input-file -linalg-fold-reshape-ops-by-linearization %s | FileCheck %s
-
-// Note: These tests fuse the reshape ops by linearization. This can create
-// indexing maps which are hard to analyse later on. These patterns are useful
-// only if the folded dimensions in the reshape op are unit extent. Tests here
-// are more general for testing purposes, but use of these pattern for non-unit
-// dimensions should be deprecated.
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func.func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
- -> tensor<?x?x4x?xi32> {
- %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] :
- tensor<?x?x?xi32> into tensor<?x?x4x?xi32>
- %1 = linalg.generic {
- indexing_maps = [#map0, #map0],
- iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
- ins(%0 : tensor<?x?x4x?xi32>)
- outs(%0 : tensor<?x?x4x?xi32>) {
- ^bb0(%arg6: i32, %arg7 : i32):
- %idx = linalg.index 0 : index
- %2 = arith.index_cast %idx : index to i32
- %3 = arith.addi %arg6, %2 : i32
- linalg.yield %3 : i32
- } -> tensor<?x?x4x?xi32>
- return %1 : tensor<?x?x4x?xi32>
-}
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
-// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK: func @generic_op_reshape_producer_fusion
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xi32>
-// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]]
-// CHECK-SAME: [0], [1, 2], [3]
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP3]], #[[MAP4]]]
-// CHECK-SAME: ins(%[[ARG0]] : tensor<?x?x?xi32>)
-// CHECK-SAME: outs(%[[T0]] : tensor<?x?x4x?xi32>)
-// CHECK: %[[IDX:.+]] = linalg.index 0 : index
-// CHECK-NEXT: %[[IDX_CASTED:.+]] = arith.index_cast %[[IDX]] : index to i32
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func.func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
- -> tensor<?x?xi32> {
- %0 = linalg.generic {
- indexing_maps = [#map0, #map0],
- iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
- ins(%arg0 : tensor<?x?x4x5xi32>) outs(%arg0 : tensor<?x?x4x5xi32>) {
- ^bb0(%arg6: i32, %arg7: i32):
- %idx = linalg.index 0 : index
- %2 = arith.index_cast %idx : index to i32
- %3 = arith.addi %arg6, %2 : i32
- linalg.yield %3 : i32
- } -> tensor<?x?x4x5xi32>
- %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] :
- tensor<?x?x4x5xi32> into tensor<?x?xi32>
- return %1 : tensor<?x?xi32>
-}
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
-// CHECK: func @generic_op_reshape_consumer_fusion
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x5xi32>
-// CHECK: %[[T0:.+]] = tensor.collapse_shape %[[ARG0]]
-// CHECK-SAME: [0], [1, 2, 3]
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]]
-// CHECK-SAME: outs(%[[T0]] : tensor<?x?xi32>)
-// CHECK: %[[IDX:.+]] = linalg.index 0 : index
-// CHECK-NEXT: %[[IDX_CASTED:.+]] = arith.index_cast %[[IDX]] : index to i32
-// CHECK-NOT: tensor.collapse_shape
-
-// -----
-
-#map2 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
-#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func.func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<3x7x5xf32> {
- %0 = tensor.expand_shape %arg0 [[0], [1, 2]]
- : tensor<3x35xf32> into tensor<3x5x7xf32>
- %1 = linalg.init_tensor [3, 7, 5] : tensor<3x7x5xf32>
- %2 = linalg.generic
- {indexing_maps = [#map2, #map3],
- iterator_types = ["parallel", "parallel", "parallel"]}
- ins(%0 : tensor<3x5x7xf32>) outs(%1 : tensor<3x7x5xf32>) {
- ^bb0(%arg2: f32, %arg3 : f32):
- linalg.yield %arg2 : f32
- } -> tensor<3x7x5xf32>
- return %2 : tensor<3x7x5xf32>
-}
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK: func @generic_op_021_permultation_reshape_producer_fusion
-// CHECK-NOT: tensor.expand_shape
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
-
-// -----
-
-#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
-#map3 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
-func.func @generic_op_120_permutation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x7x3xf32> {
- %0 = tensor.expand_shape %arg0 [[0], [1, 2]]
- : tensor<3x35xf32> into tensor<3x5x7xf32>
- %1 = linalg.init_tensor [5, 7, 3] : tensor<5x7x3xf32>
- %2 = linalg.generic
- {indexing_maps = [#map2, #map3],
- iterator_types = ["parallel", "parallel", "parallel"]}
- ins(%0 : tensor<3x5x7xf32>) outs(%1 : tensor<5x7x3xf32>) {
- ^bb0(%arg2: f32, %arg3: f32):
- linalg.yield %arg2 : f32
- } -> tensor<5x7x3xf32>
- return %2 : tensor<5x7x3xf32>
-}
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
-// CHECK: func @generic_op_120_permutation_reshape_producer_fusion
-// CHECK-NOT: tensor.expand_shape
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2) -> (d0)>
-#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
-#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func.func @generic_op_102_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x3x7xf32> {
- %0 = tensor.expand_shape %arg0 [[0], [1, 2]]
- : tensor<3x35xf32> into tensor<3x5x7xf32>
- %1 = linalg.init_tensor [5, 3, 7] : tensor<5x3x7xf32>
- %2 = linalg.generic
- {indexing_maps = [#map2, #map3],
- iterator_types = ["parallel", "parallel", "parallel"]}
- ins(%0 : tensor<3x5x7xf32>) outs(%1 : tensor<5x3x7xf32>) {
- ^bb0(%arg2: f32, %arg3: f32):
- linalg.yield %arg2 : f32
- } -> tensor<5x3x7xf32>
- return %2 : tensor<5x3x7xf32>
-}
-
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK: func @generic_op_102_permultation_reshape_producer_fusion
-// CHECK-NOT: tensor.expand_shape
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d0)>
-#map3 = affine_map<(d0, d1, d2) -> (d1, d2)>
-func.func @generic_op_102_permultation_reshape_consumer_fusion(%arg0 : tensor<3x5x7xf32>) -> tensor<5x21xf32> {
- %0 = linalg.init_tensor [5, 3, 7] : tensor<5x3x7xf32>
- %1 = linalg.generic
- {indexing_maps = [#map0, #map1],
- iterator_types = ["parallel", "parallel", "parallel"]}
- ins(%arg0 : tensor<3x5x7xf32>) outs(%0 : tensor<5x3x7xf32>) {
- ^bb0(%arg2: f32, %arg3 : f32):
- linalg.yield %arg2 : f32
- } -> tensor<5x3x7xf32>
- %2 = tensor.collapse_shape %1 [[0], [1, 2]]
- : tensor<5x3x7xf32> into tensor<5x21xf32>
- return %2 : tensor<5x21xf32>
-}
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
-// CHECK: func @generic_op_102_permultation_reshape_consumer_fusion
-// CHECK-SAME: %[[ARG0:.+]]: tensor<3x5x7xf32>
-// CHECK: %[[T0:.+]] = linalg.init_tensor [5, 3, 7]
-// CHECK: %[[T1:.+]] = tensor.collapse_shape %[[T0]]
-// CHECK-SAME: [0], [1, 2]
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]]
-// CHECK-SAME: ins(%[[ARG0]] : tensor<3x5x7xf32>)
-// CHECK-SAME: outs(%[[T1]] : tensor<5x21xf32>)
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func.func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x5xf32>,
- %arg1 : tensor<?x?x?x5xf32>) ->
- tensor<?x?xf32>
-{
- %0 = linalg.generic {
- indexing_maps = [#map0, #map0, #map0],
- iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
- ins(%arg0, %arg1 : tensor<?x?x?x5xf32>, tensor<?x?x?x5xf32>)
- outs(%arg0 : tensor<?x?x?x5xf32>) {
- ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
- %1 = arith.mulf %arg3, %arg4 : f32
- linalg.yield %1 : f32
- } -> tensor<?x?x?x5xf32>
- %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] :
- tensor<?x?x?x5xf32> into tensor<?x?xf32>
- return %1 : tensor<?x?xf32>
-}
-// CHECK-LABEL: func @generic_op_reshape_consumer_nofusion
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?x5xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?x5xf32>
-// CHECK: %[[NOFUSE:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
-// CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[NOFUSE]]
-// CHECK: return %[[RESULT]]
-
-
-// -----
-
-func.func @generic_op_permultation_reshape_consumer_fusion_unused_dim(%arg0 : tensor<6x1xf32>) -> tensor<6xi32> {
- %0 = linalg.init_tensor [6, 1] : tensor<6x1xi32>
- %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]}
- ins(%arg0 : tensor<6x1xf32>) outs(%0 : tensor<6x1xi32>) {
- ^bb0(%arg3: f32, %arg4: i32):
- %5 = arith.fptosi %arg3 : f32 to i32
- linalg.yield %5 : i32
- } -> tensor<6x1xi32>
- %6 = tensor.collapse_shape %1 [[0, 1]] : tensor<6x1xi32> into tensor<6xi32>
- return %6 : tensor<6xi32>
-}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
-// CHECK: func @generic_op_permultation_reshape_consumer_fusion_unused_dim
-// CHECK-SAME: %[[ARG0:.+]]: tensor<6x1xf32>
-// CHECK: %[[T0:.+]] = linalg.init_tensor [6, 1]
-// CHECK: %[[T1:.+]] = tensor.collapse_shape %[[T0]]
-// CHECK-SAME: [0, 1]
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME: ins(%[[ARG0]] : tensor<6x1xf32>)
-// CHECK-SAME: outs(%[[T1]] : tensor<6xi32>)
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d4, d0, d6, d3, d5, d1)>
-#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>
-func.func @permuted_dims_fusion_expand_shape(%arg0 : tensor<3x8x7x240xf32>) -> tensor<4x6x3x8x2x5x7xf32> {
- %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6]]
- : tensor<3x8x7x240xf32> into tensor<3x2x4x7x8x5x6xf32>
- %1 = linalg.init_tensor [4, 6, 3, 8, 2, 5, 7] : tensor<4x6x3x8x2x5x7xf32>
- %2 = linalg.generic {
- indexing_maps = [#map0, #map1],
- iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
- ins(%0 : tensor<3x2x4x7x8x5x6xf32>) outs(%1 : tensor<4x6x3x8x2x5x7xf32>) {
- ^bb0(%arg1 : f32, %arg2 : f32):
- linalg.yield %arg1 : f32
- } -> tensor<4x6x3x8x2x5x7xf32>
- return %2 : tensor<4x6x3x8x2x5x7xf32>
-}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d0 + d4 * 4, d6, d1 + d3 * 30 + d5 * 6)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>
-// CHECK: func @permuted_dims_fusion_expand_shape(
-// CHECK-SAME: %[[ARG0:.+]]: tensor<3x8x7x240xf32>)
-// CHECK: %[[RESULT:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME: ins(%[[ARG0]] :
-// CHECK: return %[[RESULT]]
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d4, d0, d6, d3, d5, d1)>
-#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>
-func.func @permuted_dims_fusion_collapse_shape(%arg0 : tensor<4x6x3x8x2x5x7xf32>) -> tensor<3x8x7x240xf32> {
- %0 = linalg.init_tensor [3, 2, 4, 7, 8, 5, 6] : tensor<3x2x4x7x8x5x6xf32>
- %1 = linalg.generic {
- indexing_maps = [#map1, #map0],
- iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
- ins(%arg0 : tensor<4x6x3x8x2x5x7xf32>) outs(%0 : tensor<3x2x4x7x8x5x6xf32>) {
- ^bb0(%arg1 : f32, %arg2 : f32):
- linalg.yield %arg1 : f32
- } -> tensor<3x2x4x7x8x5x6xf32>
- %2 = tensor.collapse_shape %1 [[0], [1, 2], [3], [4, 5, 6]]
- : tensor<3x2x4x7x8x5x6xf32> into tensor<3x8x7x240xf32>
- return %2 : tensor<3x8x7x240xf32>
-}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d0 + d4 * 4, d6, d1 + d3 * 30 + d5 * 6)>
-// CHECK: func @permuted_dims_fusion_collapse_shape(
-// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x3x8x2x5x7xf32>)
-// CHECK: %[[RESULT:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME: ins(%[[ARG0]] :
-// CHECK: return %[[RESULT]]
diff --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion_with_unit_dims.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion_with_unit_dims.mlir
deleted file mode 100644
index 80826057c6bd3..0000000000000
--- a/mlir/test/Dialect/Linalg/reshape_linearization_fusion_with_unit_dims.mlir
+++ /dev/null
@@ -1,52 +0,0 @@
-// RUN: mlir-opt -linalg-fold-reshape-ops-by-linearization=allow-folding-unit-dim-reshapes -split-input-file %s | FileCheck %s
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-func.func @do_not_fold1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?x1xf32>
-{
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
- %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
- %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
- %3 = linalg.generic {
- indexing_maps = [#map, #map, #map],
- iterator_types = ["parallel", "parallel"]}
- ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%2 : tensor<?x?xf32>) {
- ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
- %4 = arith.addf %arg2, %arg3 : f32
- linalg.yield %4 : f32
- } -> tensor<?x?xf32>
- %4 = tensor.expand_shape %3 [[0], [1, 2]] : tensor<?x?xf32> into tensor<?x?x1xf32>
- return %4 : tensor<?x?x1xf32>
-}
-// CHECK-LABEL: func @do_not_fold1
-// CHECK: %[[VAL:.+]] = linalg.generic
-// CHECK: tensor.expand_shape %[[VAL]]
-
-// -----
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-func.func @do_not_fold2(%arg0 : tensor<?x?x1xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32>
-{
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %0 = tensor.collapse_shape %arg0 [[0], [1, 2]] : tensor<?x?x1xf32> into tensor<?x?xf32>
- %1 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
- %2 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
- %3 = linalg.init_tensor [%1, %2] : tensor<?x?xf32>
- %4 = linalg.generic {
- indexing_maps = [#map, #map, #map],
- iterator_types = ["parallel", "parallel"]}
- ins(%0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%3 : tensor<?x?xf32>) {
- ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
- %4 = arith.addf %arg2, %arg3 : f32
- linalg.yield %4 : f32
- } -> tensor<?x?xf32>
- return %4 : tensor<?x?xf32>
-}
-// CHECK-LABEL: func @do_not_fold2
-// CHECK: %[[VAL:.+]] = tensor.collapse_shape
-// CHECK: linalg.generic
-// CHECK-SAME: ins(%[[VAL]], %{{.+}} : tensor<?x?xf32>, tensor<?x?xf32>)
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index 211ddcfc3730a..ec36b0fe9d266 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -70,18 +70,18 @@ struct TestLinalgElementwiseFusion
llvm::cl::desc("Test fusion of generic operations."),
llvm::cl::init(false)};
+ Option<bool> fuseWithReshapeByExpansion{
+ *this, "fuse-with-reshape-by-expansion",
+ llvm::cl::desc(
+ "Test fusion of generic operations with reshape by expansion"),
+ llvm::cl::init(false)};
+
Option<bool> controlFuseByExpansion{
*this, "control-fusion-by-expansion",
llvm::cl::desc(
"Test controlling fusion of reshape with generic op by expansion"),
llvm::cl::init(false)};
- Option<bool> pushExpandingReshape{
- *this, "push-expanding-reshape",
- llvm::cl::desc("Test linalg expand_shape -> generic "
- "to generic -> expand_shape pattern"),
- llvm::cl::init(false)};
-
Option<bool> fuseWithReshapeByCollapsing{
*this, "fuse-with-reshape-by-collapsing",
llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that "
@@ -109,6 +109,17 @@ struct TestLinalgElementwiseFusion
return;
}
+ if (fuseWithReshapeByExpansion) {
+ RewritePatternSet fusionPatterns(context);
+ linalg::populateFoldReshapeOpsByExpansionPatterns(
+ fusionPatterns, [](const OpResult & /*producer*/,
+ OpOperand & /*consumer*/) { return true; });
+ if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+ std::move(fusionPatterns))))
+ return signalPassFailure();
+ return;
+ }
+
if (controlFuseByExpansion) {
RewritePatternSet fusionPatterns(context);
@@ -128,8 +139,9 @@ struct TestLinalgElementwiseFusion
if (linalgOp && linalgOp.isOutputTensor(&use))
return true;
}
+ return false;
}
- return linalg::skipUnitDimReshape(producer, consumer);
+ return true;
};
linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
@@ -139,12 +151,6 @@ struct TestLinalgElementwiseFusion
return;
}
- if (pushExpandingReshape) {
- RewritePatternSet patterns(context);
- linalg::populatePushReshapeOpsPatterns(patterns);
- (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
- }
-
if (fuseWithReshapeByCollapsing) {
RewritePatternSet patterns(context);
linalg::populateFoldReshapeOpsByCollapsingPatterns(