[Mlir-commits] [mlir] 6739993 - [mlir][Linalg] Cleanup the drop unit dims pass in Linalg.
Mahesh Ravishankar
llvmlistbot at llvm.org
Wed Jul 19 10:47:38 PDT 2023
Author: Mahesh Ravishankar
Date: 2023-07-19T17:47:18Z
New Revision: 67399932c767f0a64c83a500dc6f7806c09d9401
URL: https://github.com/llvm/llvm-project/commit/67399932c767f0a64c83a500dc6f7806c09d9401
DIFF: https://github.com/llvm/llvm-project/commit/67399932c767f0a64c83a500dc6f7806c09d9401.diff
LOG: [mlir][Linalg] Cleanup the drop unit dims pass in Linalg.
TL;DR: the following API functions have been merged
```
void populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns);
void populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns);
```
into
```
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns,
                                        ControlDropUnitDims &options);
```
To recover the previous functionality, use
```
ControlDropUnitDims options;
// By default options.rankReductionStrategy is
// ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape.
populateFoldUnitExtentDimsPatterns(patterns, options);
```
and
```
ControlDropUnitDims options;
options.rankReductionStrategy =
    ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
populateFoldUnitExtentDimsPatterns(patterns, options);
```
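`ControlDropUnitDims` also carries a control function that restricts which loop dimensions of a given op may be dropped; by default, all loops of a `linalg.generic` are candidates. As a minimal sketch (mirroring the test pass added below), the following restricts folding to the outermost loop dimension:
```
ControlDropUnitDims options;
// Allow only loop dimension 0 of each op to be dropped (when unit extent).
options.controlFn = [](Operation *op) { return SmallVector<unsigned>{0}; };
populateFoldUnitExtentDimsPatterns(patterns, options);
```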
This pass is quite old and needed to be updated to match the current
approach to transformations in Linalg:
- Instead of two patterns, one that only removed loop dimensions that
  are unit extent (using 0 in the indexing maps) and another that
  dropped the unit extents in the operand shapes, combine them into a
  single transformation. This avoids creating an intermediate step with
  indexing maps having 0s in their domain expressions.
- Expose the core transformation as a utility function and add a
  pattern that calls this transformation (see the sketch below).
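As an illustration, here is a minimal sketch of driving the exposed utility directly with an `IRRewriter`, modeled on the `TestLinalgDropUnitDims` pass added below; the helper name `dropUnitDimsInFunc` is hypothetical, while `linalg::dropUnitDims` and `linalg::ControlDropUnitDims` are the entry points introduced by this change:
```
// Hypothetical helper: collect the linalg.generic ops up front, since the
// rewrite replaces them in place, then run the exposed utility on each.
static void dropUnitDimsInFunc(func::FuncOp funcOp) {
  IRRewriter rewriter(funcOp.getContext());
  SmallVector<linalg::GenericOp> genericOps;
  funcOp.walk([&](linalg::GenericOp op) { genericOps.push_back(op); });
  // Default strategy: reassociative reshapes; all loops are candidates.
  linalg::ControlDropUnitDims options;
  for (linalg::GenericOp op : genericOps) {
    rewriter.setInsertionPoint(op);
    (void)linalg::dropUnitDims(rewriter, op, options);
  }
}
```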
This is mostly an NFC change, apart from the API change and the removal
of the pattern/test that only dropped unit-extent loops.
Differential Revision: https://reviews.llvm.org/D155518
Added:
mlir/test/Dialect/Linalg/test-drop-unit-dims.mlir
mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
mlir/test/lib/Dialect/Linalg/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 1ed867b3a0df75..3093604af63e33 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -28,10 +28,6 @@ def LinalgFoldUnitExtentDims : Pass<"linalg-fold-unit-extent-dims", ""> {
let summary = "Remove unit-extent dimension in Linalg ops on tensors";
let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()";
let options = [
- Option<"foldOneTripLoopsOnly", "fold-one-trip-loops-only", "bool",
- /*default=*/"false",
- "Only folds the one-trip loops from Linalg ops on tensors "
- "(for testing purposes only)">,
Option<"useRankReducingSlices", "use-rank-reducing-slices", "bool",
/*default=*/"false",
"Generate rank-reducing slices instead of reassociative reshapes">
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a78dc1e1e571bc..68fce05c0d2221 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -419,6 +419,25 @@ LogicalResult vectorizeOpPrecondition(Operation *op,
using LinalgLoops = SmallVector<Operation *, 4>;
+/// Transformation to drop unit-extent dimensions from `linalg.generic`
+/// operations.
+struct ControlDropUnitDims {
+ enum class RankReductionStrategy { ReassociativeReshape, ExtractInsertSlice };
+
+ RankReductionStrategy rankReductionStrategy =
+ RankReductionStrategy::ReassociativeReshape;
+
+ using ControlFnTy = std::function<SmallVector<unsigned>(Operation *)>;
+ ControlFnTy controlFn = [](Operation *op) {
+ if (auto genericOp = dyn_cast_or_null<GenericOp>(op)) {
+ return llvm::to_vector(llvm::seq<unsigned>(0, genericOp.getNumLoops()));
+ }
+ return SmallVector<unsigned>{};
+ };
+};
+LogicalResult dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
+ const ControlDropUnitDims &options);
+
/// Fuse two `linalg.generic` operations that have a producer-consumer
/// relationship captured through `fusedOperand`. The method expects
/// that `areElementwiseOpsFusable` returns true for the given `fusedOperand`.
@@ -1496,11 +1515,8 @@ void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);
/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
/// tensors via reassociative reshape ops.
-void populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns);
-
-/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
-/// tensors via rank-reducing slices.
-void populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns);
+void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns,
+ ControlDropUnitDims &options);
/// A pattern that converts init operands to input operands.
void populateMoveInitOperandsToInputPattern(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f6e0f27548a20c..40b602ffd4f80b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -156,12 +156,16 @@ void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
- linalg::populateFoldUnitExtentDimsViaReshapesPatterns(patterns);
+ linalg::ControlDropUnitDims options;
+ linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
}
void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
- linalg::populateFoldUnitExtentDimsViaSlicesPatterns(patterns);
+ linalg::ControlDropUnitDims options;
+ options.rankReductionStrategy =
+ linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
+ linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
}
void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 894036a535302e..b33c75ca94fc0b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -43,196 +43,6 @@ using namespace mlir;
using namespace mlir::linalg;
namespace {
-enum class RankReductionStrategy { ReassociativeReshape, ExtractInsertSlice };
-} // namespace
-
-/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
-/// broadcasting. For example,
-///
-/// ```mlir
-/// #accesses = [
-/// affine_map<(d0, d1) -> (0, d1)>,
-/// affine_map<(d0, d1) -> (d0, 0)>,
-/// affine_map<(d0, d1) -> (d0, d1)>
-/// ]
-///
-/// #trait = {
-/// args_in = 2,
-/// args_out = 1,
-/// indexing_maps = #accesses,
-/// iterator_types = ["parallel", "parallel"],
-/// library_call = "some_external_fn"
-/// }
-///
-/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
-/// tensor<5x5xf32>
-/// {
-/// %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
-/// tensor<5xf32> into tensor<1x5xf32>
-/// %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
-/// tensor<5xf32> into tensor<5x1xf32>
-/// %2 = linalg.generic #trait %0, %1 {
-/// ^bb0(%arg2: f32, %arg3: f32):
-/// %3 = arith.addf %arg2, %arg3 : f32
-/// linalg.yield %3 : f32
-/// } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
-/// return %2 : tensor<5x5xf32>
-/// }
-///
-/// would canonicalize to
-///
-/// ```mlir
-/// #accesses = [
-/// affine_map<(d0, d1) -> (d1)>,
-/// affine_map<(d0, d1) -> (d0)>,
-/// affine_map<(d0, d1) -> (d0, d1)>
-/// ]
-///
-/// #trait = {
-/// args_in = 2,
-/// args_out = 1,
-/// indexing_maps = #accesses,
-/// iterator_types = ["parallel", "parallel"],
-/// library_call = "some_external_fn"
-/// }
-///
-/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
-/// tensor<5x5xf32>
-/// {
-/// %0 = linalg.generic #trait %arg0, %arg1 {
-/// ^bb0(%arg2: f32, %arg3: f32):
-/// %3 = arith.addf %arg2, %arg3 : f32
-/// linalg.yield %3 : f32
-/// } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
-/// return %0 : tensor<5x5xf32>
-/// }
-
-/// Given dims of the iteration space of a structured op that are known to be
-/// single trip count (`unitDims`), return the indexing maps to use in the
-/// canonicalized op with these dims removed, given the original `indexingMaps`.
-static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
- ArrayRef<AffineMap> indexingMaps,
- MLIRContext *context) {
- if (indexingMaps.empty())
- return nullptr;
- unsigned numIterationDims = indexingMaps.front().getNumDims();
- unsigned numSymbols = indexingMaps.front().getNumSymbols();
-
- // Compute the replacement for each dim expr.
- SmallVector<AffineExpr, 4> dimReplacements;
- dimReplacements.reserve(numIterationDims);
- unsigned numKeptDims = 0;
- for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
- if (unitDims.count(dim))
- dimReplacements.push_back(getAffineConstantExpr(0, context));
- else
- dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
- }
-
- // Symbols remain the same.
- SmallVector<AffineExpr, 4> symReplacements;
- symReplacements.reserve(numSymbols);
- for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
- symReplacements.push_back(getAffineSymbolExpr(symbol, context));
-
- SmallVector<AffineMap, 4> newIndexingMaps;
- newIndexingMaps.reserve(indexingMaps.size());
- for (AffineMap operandMap : indexingMaps) {
- // Expected indexing maps to have no symbols.
- if (operandMap.getNumSymbols())
- return nullptr;
- newIndexingMaps.push_back(simplifyAffineMap(
- operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
- numIterationDims - unitDims.size(),
- numSymbols)));
- }
-
- // Check that the new index maps are invertible. If not, something went
- // wrong, so abort.
- if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
- return nullptr;
- return ArrayAttr::get(context,
- llvm::to_vector<4>(llvm::map_range(
- newIndexingMaps, [](AffineMap map) -> Attribute {
- return AffineMapAttr::get(map);
- })));
-}
-
-/// Update the index accesses of linalg operations having index semantics.
-static void replaceUnitDimIndexOps(GenericOp genericOp,
- const DenseSet<unsigned> &unitDims,
- PatternRewriter &rewriter) {
- for (IndexOp indexOp :
- llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPoint(indexOp);
- if (unitDims.count(indexOp.getDim()) != 0) {
- rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
- } else {
- // Update the dimension of the index operation if needed.
- unsigned droppedDims = llvm::count_if(
- unitDims, [&](unsigned dim) { return dim < indexOp.getDim(); });
- if (droppedDims != 0)
- rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
- indexOp.getDim() - droppedDims);
- }
- }
-}
-
-namespace {
-/// Pattern to fold unit-trip count loops in GenericOps.
-struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
- using OpRewritePattern<GenericOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(GenericOp genericOp,
- PatternRewriter &rewriter) const override {
- SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMapsArray();
- if (indexingMaps.empty())
- return failure();
-
- // Check if any of the iteration dimensions are unit-trip count. They will
- // end up being unit-trip count if they are used to index into a unit-dim
- // tensor/memref.
- AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
- if (!invertedMap)
- return failure();
- SmallVector<int64_t> dims = genericOp.getStaticShape();
-
- DenseSet<unsigned> unitDims;
- SmallVector<unsigned, 4> unitDimsReductionLoops;
- ArrayAttr iteratorTypes = genericOp.getIteratorTypes();
- for (const auto &expr : enumerate(invertedMap.getResults())) {
- if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
- if (dims[dimExpr.getPosition()] == 1)
- unitDims.insert(expr.index());
- }
-
- if (unitDims.empty())
- return failure();
-
- // Compute the modified indexing maps.
- MLIRContext *context = rewriter.getContext();
- ArrayAttr newIndexingMapAttr =
- replaceUnitDims(unitDims, indexingMaps, context);
- if (!newIndexingMapAttr)
- return genericOp.emitError("unable to compute modified indexing_maps");
-
- // Compute the iterator types of the modified op by dropping the one-trip
- // count loops.
- SmallVector<Attribute, 4> newIteratorTypes;
- for (const auto &attr : llvm::enumerate(iteratorTypes)) {
- if (!unitDims.count(attr.index()))
- newIteratorTypes.push_back(attr.value());
- }
-
- rewriter.startRootUpdate(genericOp);
- genericOp.setIndexingMapsAttr(newIndexingMapAttr);
- genericOp.setIteratorTypesAttr(ArrayAttr::get(context, newIteratorTypes));
- replaceUnitDimIndexOps(genericOp, unitDims, rewriter);
- rewriter.finalizeRootUpdate(genericOp);
- return success();
- }
-};
-
/// Pattern to move init operands to ins when all the loops are parallel and
/// blockArgument corresponding to init is used in the region. This is a fix-up
/// when unit reduction dimensions are all folded away. In this context, it
@@ -351,243 +161,405 @@ struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
return success();
}
};
+} // namespace
+
+//===---------------------------------------------------------------------===//
+// Drop loops that are unit-extents within Linalg operations.
+//===---------------------------------------------------------------------===//
+
+/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
+/// broadcasting. For example,
+///
+/// ```mlir
+/// #accesses = [
+/// affine_map<(d0, d1) -> (0, d1)>,
+/// affine_map<(d0, d1) -> (d0, 0)>,
+/// affine_map<(d0, d1) -> (d0, d1)>
+/// ]
+///
+/// #trait = {
+/// args_in = 2,
+/// args_out = 1,
+/// indexing_maps = #accesses,
+/// iterator_types = ["parallel", "parallel"],
+/// library_call = "some_external_fn"
+/// }
+///
+/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
+/// tensor<5x5xf32>
+/// {
+/// %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
+/// tensor<5xf32> into tensor<1x5xf32>
+/// %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
+/// tensor<5xf32> into tensor<5x1xf32>
+/// %2 = linalg.generic #trait %0, %1 {
+/// ^bb0(%arg2: f32, %arg3: f32):
+/// %3 = arith.addf %arg2, %arg3 : f32
+/// linalg.yield %3 : f32
+/// } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
+/// return %2 : tensor<5x5xf32>
+/// }
+///
+/// would canonicalize to
+///
+/// ```mlir
+/// #accesses = [
+/// affine_map<(d0, d1) -> (d1)>,
+/// affine_map<(d0, d1) -> (d0)>,
+/// affine_map<(d0, d1) -> (d0, d1)>
+/// ]
+///
+/// #trait = {
+/// args_in = 2,
+/// args_out = 1,
+/// indexing_maps = #accesses,
+/// iterator_types = ["parallel", "parallel"],
+/// library_call = "some_external_fn"
+/// }
+///
+/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
+/// tensor<5x5xf32>
+/// {
+/// %0 = linalg.generic #trait %arg0, %arg1 {
+/// ^bb0(%arg2: f32, %arg3: f32):
+/// %3 = arith.addf %arg2, %arg3 : f32
+/// linalg.yield %3 : f32
+/// } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
+/// return %0 : tensor<5x5xf32>
+/// }
+/// Update the index accesses of linalg operations having index semantics.
+static void
+replaceUnitDimIndexOps(GenericOp genericOp,
+ const llvm::SmallDenseSet<unsigned> &unitDims,
+ RewriterBase &rewriter) {
+ for (IndexOp indexOp :
+ llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(indexOp);
+ if (unitDims.count(indexOp.getDim()) != 0) {
+ rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
+ } else {
+ // Update the dimension of the index operation if needed.
+ unsigned droppedDims = llvm::count_if(
+ unitDims, [&](unsigned dim) { return dim < indexOp.getDim(); });
+ if (droppedDims != 0)
+ rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
+ indexOp.getDim() - droppedDims);
+ }
+ }
+}
+
+/// Expand the given `value` so that the type matches the type of `origDest`.
+/// The `reassociation` is used when `rankReductionStrategy` is set to
+/// `RankReductionStrategy::ReassociativeReshape`.
+static Value
+expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
+ ArrayRef<ReassociationIndices> reassociation,
+ ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
+ // There are no results for memref outputs.
+ auto origResultType = cast<RankedTensorType>(origDest.getType());
+ if (rankReductionStrategy ==
+ ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
+ unsigned rank = origResultType.getRank();
+ SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+ SmallVector<OpFoldResult> sizes =
+ tensor::getMixedSizes(rewriter, loc, origDest);
+ SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+ return rewriter.createOrFold<tensor::InsertSliceOp>(
+ loc, result, origDest, offsets, sizes, strides);
+ }
+
+ assert(rankReductionStrategy ==
+ ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
+ "unknown rank reduction strategy");
+ return rewriter.create<tensor::ExpandShapeOp>(loc, origResultType, result,
+ reassociation);
+}
+
+/// Collapse the given `value` so that the type matches the type of
+/// `origOutput`. The `reassociation` is used when `rankReductionStrategy` is
+/// set to `RankReductionStrategy::ReassociativeReshape`.
+static Value collapseValue(
+ RewriterBase &rewriter, Location loc, Value operand,
+ ArrayRef<int64_t> targetShape, ArrayRef<ReassociationIndices> reassociation,
+ ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
+ if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
+ if (rankReductionStrategy ==
+ ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
+ FailureOr<Value> rankReducingExtract =
+ memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
+ targetShape);
+ assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
+ return *rankReducingExtract;
+ }
+
+ assert(
+ rankReductionStrategy ==
+ ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
+ "unknown rank reduction strategy");
+ MemRefLayoutAttrInterface layout;
+ auto targetType = MemRefType::get(targetShape, memrefType.getElementType(),
+ layout, memrefType.getMemorySpace());
+ return rewriter.create<memref::CollapseShapeOp>(loc, targetType, operand,
+ reassociation);
+ }
+ if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) {
+ if (rankReductionStrategy ==
+ ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
+ FailureOr<Value> rankReducingExtract =
+ tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
+ targetShape);
+ assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
+ return *rankReducingExtract;
+ }
+
+ assert(
+ rankReductionStrategy ==
+ ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
+ "unknown rank reduction strategy");
+ auto targetType =
+ RankedTensorType::get(targetShape, tensorType.getElementType());
+ return rewriter.create<tensor::CollapseShapeOp>(loc, targetType, operand,
+ reassociation);
+ }
+ llvm_unreachable("unsupported operand type");
+}
+
+/// Compute the modified metadata for an operand of an operation
+/// whose unit dims are being dropped. Return the new indexing map
+/// to use, the shape of the operand in the replacement op,
+/// and the `reassociation` to use to go from the original operand
+/// shape to the modified operand shape.
struct UnitExtentReplacementInfo {
AffineMap indexMap;
SmallVector<ReassociationIndices> reassociation;
SmallVector<int64_t> targetShape;
};
-} // namespace
-
-/// Utility function for replacing operands/results to a linalg generic
-/// operation with unit-extent dimensions. These can be replaced with
-/// an operand/result with the unit-extent dimension removed. This is only done
-/// if the indexing map used to access that dimension has a
-/// AffineConstantExpr of value 0. Given the `type` of an result/operand of a
-/// Linalg op, and its `indexMap` the utility function returns:
-/// - the new type with dimensions of size 1 removed.
-/// - modified index map that can be used to access the replaced result/operand
-/// - the reassociation that converts from the original tensor type to the
-/// modified tensor type.
-static std::optional<UnitExtentReplacementInfo>
-replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
- MLIRContext *context) {
+static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
+ MLIRContext *context, GenericOp genericOp, OpOperand *opOperand,
+ llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
+ ArrayRef<AffineExpr> dimReplacements) {
+ UnitExtentReplacementInfo info;
+ ReassociationIndices reassociationGroup;
+ SmallVector<AffineExpr> newIndexExprs;
AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
- ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
+ ArrayRef<int64_t> operandShape = genericOp.getShape(opOperand);
ArrayRef<AffineExpr> exprs = indexingMap.getResults();
- SmallVector<AffineExpr> newIndexExprs;
- SmallVector<int64_t> newShape;
- int64_t origRank = genericOp.getRank(opOperand);
- AffineExpr zeroExpr = getAffineConstantExpr(0, context);
- auto isUnitExtent = [&](int64_t dim) -> bool {
- return shape[dim] == 1 && exprs[dim] == zeroExpr;
+ auto isUnitDim = [&](unsigned dim) {
+ if (auto dimExpr = exprs[dim].dyn_cast<AffineDimExpr>()) {
+ unsigned oldPosition = dimExpr.getPosition();
+ return !oldDimsToNewDimsMap.count(oldPosition);
+ }
+ // Handle the other case where the shape is 1, and is accessed using a
+ // constant 0.
+ if (operandShape[dim] == 1) {
+ auto constAffineExpr = exprs[dim].dyn_cast<AffineConstantExpr>();
+ return constAffineExpr && constAffineExpr.getValue() == 0;
+ }
+ return false;
};
- // Early return for memrefs with affine maps to represent that we will always
- // leave them unchanged.
- Type actualType = opOperand->get().getType();
- if (auto memref = dyn_cast<MemRefType>(actualType)) {
- if (!memref.getLayout().isIdentity())
- return std::nullopt;
- }
-
int64_t dim = 0;
- SmallVector<ReassociationIndices> reassociation;
- ReassociationIndices reassociationGroup;
- // Fold dimensions that are unit-extent at the beginning of the tensor.
- while (dim < origRank && isUnitExtent(dim))
+ while (dim < operandShape.size() && isUnitDim(dim))
reassociationGroup.push_back(dim++);
- while (dim < origRank) {
- assert(!isUnitExtent(dim) && "expected non unit-extent");
+ while (dim < operandShape.size()) {
+ assert(!isUnitDim(dim) && "expected non unit-extent");
reassociationGroup.push_back(dim);
- newIndexExprs.push_back(exprs[dim]);
- newShape.push_back(shape[dim]);
+ AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
+ newIndexExprs.push_back(newExpr);
+ info.targetShape.push_back(operandShape[dim]);
++dim;
// Fold all following dimensions that are unit-extent.
- while (dim < origRank && isUnitExtent(dim))
+ while (dim < operandShape.size() && isUnitDim(dim)) {
reassociationGroup.push_back(dim++);
- reassociation.push_back(reassociationGroup);
+ }
+ info.reassociation.push_back(reassociationGroup);
reassociationGroup.clear();
}
-
- // Return if the rank was not reduced.
- if (origRank == static_cast<int64_t>(newShape.size()))
- return std::nullopt;
-
- UnitExtentReplacementInfo info = {
- /*indexMap=*/AffineMap::get(indexingMap.getNumDims(),
- indexingMap.getNumSymbols(), newIndexExprs,
- context),
- /*reassociation=*/reassociation, /*targetShape=*/newShape};
+ info.indexMap =
+ AffineMap::get(oldDimsToNewDimsMap.size(), indexingMap.getNumSymbols(),
+ newIndexExprs, context);
return info;
}
-namespace {
+LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
+ const ControlDropUnitDims &options) {
+ SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+ if (indexingMaps.empty())
+ return failure();
+
+ // 1. Check if any of the iteration dimensions are unit-trip count. They will
+ // end up being unit-trip count if they are used to index into a unit-dim
+ // tensor/memref.
+ AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
+ if (!invertedMap) {
+ return rewriter.notifyMatchFailure(genericOp,
+ "invalid indexing maps for operation");
+ }
+ SmallVector<int64_t> dims = genericOp.getStaticShape();
-/// Pattern to replace tensor/buffer operands/results that are unit extents.
-struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
- ReplaceUnitExtents(MLIRContext *ctx,
- RankReductionStrategy rankReductionStrategy)
- : OpRewritePattern<GenericOp>(ctx),
- rankReductionStrategy(rankReductionStrategy) {}
-
- // Expand the given value.
- Value expandValue(Value result, Value origOutput,
- ArrayRef<ReassociationIndices> reassociation, Location loc,
- PatternRewriter &rewriter) const {
- // There are no results for memref outputs.
- auto origResultType = cast<RankedTensorType>(origOutput.getType());
- if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
- unsigned rank = origResultType.getRank();
- SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
- SmallVector<OpFoldResult> sizes =
- tensor::getMixedSizes(rewriter, loc, origOutput);
- SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
- return rewriter.createOrFold<tensor::InsertSliceOp>(
- loc, result, origOutput, offsets, sizes, strides);
+ // 1a. Get the allowed list of dimensions to drop from the `options`.
+ SmallVector<unsigned> allowedUnitDims = options.controlFn(genericOp);
+ if (allowedUnitDims.empty()) {
+ return rewriter.notifyMatchFailure(
+ genericOp, "control function returns no allowed unit dims to prune");
+ }
+ llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
+ allowedUnitDims.end());
+ llvm::SmallDenseSet<unsigned> unitDims;
+ ArrayAttr iteratorTypes = genericOp.getIteratorTypes();
+ for (const auto &expr : enumerate(invertedMap.getResults())) {
+ if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>()) {
+ if (dims[dimExpr.getPosition()] == 1 &&
+ unitDimsFilter.count(expr.index()))
+ unitDims.insert(expr.index());
}
+ }
- assert(rankReductionStrategy ==
- RankReductionStrategy::ReassociativeReshape &&
- "unknown rank reduction strategy");
- return rewriter.create<tensor::ExpandShapeOp>(loc, origResultType, result,
- reassociation);
+ // 2. Compute the iterator types of the modified op by dropping the one-trip
+ // count loops.
+ SmallVector<utils::IteratorType> newIteratorTypes;
+ llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
+ SmallVector<AffineExpr> dimReplacements;
+ unsigned newDims = 0;
+ for (auto [index, attr] :
+ llvm::enumerate(genericOp.getIteratorTypesArray())) {
+ if (unitDims.count(index)) {
+ dimReplacements.push_back(
+ getAffineConstantExpr(0, rewriter.getContext()));
+ } else {
+ newIteratorTypes.push_back(attr);
+ oldDimToNewDimMap[index] = newDims;
+ dimReplacements.push_back(
+ getAffineDimExpr(newDims, rewriter.getContext()));
+ newDims++;
+ }
}
- // Collapse the given value.
- Value collapseValue(Value operand, ArrayRef<int64_t> targetShape,
- ArrayRef<ReassociationIndices> reassociation,
- Location loc, PatternRewriter &rewriter) const {
- if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
- if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
- FailureOr<Value> rankReducingExtract =
- memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
- targetShape);
- assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
- return *rankReducingExtract;
- }
-
- assert(rankReductionStrategy ==
- RankReductionStrategy::ReassociativeReshape &&
- "unknown rank reduction strategy");
- MemRefLayoutAttrInterface layout;
- auto targetType =
- MemRefType::get(targetShape, memrefType.getElementType(), layout,
- memrefType.getMemorySpace());
- return rewriter.create<memref::CollapseShapeOp>(loc, targetType, operand,
- reassociation);
+ // 3. For each of the operands, find the
+ // - modified affine map to use.
+ // - shape of the operands after the unit-dims are dropped.
+ // - the reassociation indices used to convert from the original
+ // operand type to modified operand (needed only when using reshapes
+ // for rank reduction strategy)
+ // Note that the indexing maps might need changing even if there are no
+ // unit dimensions that are dropped to handle cases where `0` is used to
+ // access a unit-extent tensor. Consider moving this out of this specific
+ // transformation as a stand-alone transformation. Kept here right now due
+ // to legacy.
+ SmallVector<AffineMap> newIndexingMaps;
+ SmallVector<SmallVector<ReassociationIndices>> reassociations;
+ SmallVector<SmallVector<int64_t>> targetShapes;
+ SmallVector<bool> collapsed;
+ auto hasCollapsibleType = [](OpOperand &operand) {
+ Type operandType = operand.get().getType();
+ if (auto memrefOperandType = dyn_cast_or_null<MemRefType>(operandType)) {
+ return memrefOperandType.getLayout().isIdentity();
+ } else if (auto tensorOperandType =
+ dyn_cast<RankedTensorType>(operandType)) {
+ return tensorOperandType.getEncoding() == nullptr;
}
- if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) {
- if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
- FailureOr<Value> rankReducingExtract =
- tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
- targetShape);
- assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
- return *rankReducingExtract;
- }
-
- assert(rankReductionStrategy ==
- RankReductionStrategy::ReassociativeReshape &&
- "unknown rank reduction strategy");
- auto targetType =
- RankedTensorType::get(targetShape, tensorType.getElementType());
- return rewriter.create<tensor::CollapseShapeOp>(loc, targetType, operand,
- reassociation);
+ return false;
+ };
+ for (OpOperand &opOperand : genericOp->getOpOperands()) {
+ auto indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
+ ArrayRef<int64_t> shape = genericOp.getShape(&opOperand);
+ if (!hasCollapsibleType(opOperand)) {
+ AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
+ dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size(), 0);
+ newIndexingMaps.push_back(newIndexingMap);
+ targetShapes.push_back(llvm::to_vector(shape));
+ collapsed.push_back(false);
+ reassociations.push_back({});
+ continue;
}
- llvm_unreachable("unsupported operand type");
+ auto replacementInfo = dropUnitExtentFromOperandMetadata(
+ rewriter.getContext(), genericOp, &opOperand, oldDimToNewDimMap,
+ dimReplacements);
+ reassociations.push_back(replacementInfo.reassociation);
+ newIndexingMaps.push_back(replacementInfo.indexMap);
+ targetShapes.push_back(replacementInfo.targetShape);
+ collapsed.push_back(!(replacementInfo.indexMap.getNumResults() ==
+ indexingMap.getNumResults()));
}
- LogicalResult matchAndRewrite(GenericOp genericOp,
- PatternRewriter &rewriter) const override {
- // Skip the pattern if the op has any tensor with special encoding.
- if (llvm::any_of(genericOp->getOperandTypes(), [](Type type) {
- auto tensorType = dyn_cast<RankedTensorType>(type);
- return tensorType && tensorType.getEncoding() != nullptr;
- }))
- return failure();
- MLIRContext *context = rewriter.getContext();
- Location loc = genericOp.getLoc();
- SmallVector<Value> oldOutputs(genericOp.getOutputs().begin(),
- genericOp.getOutputs().end());
-
- SmallVector<AffineMap> newIndexingMaps;
- SmallVector<SmallVector<ReassociationIndices>> reassociations;
- SmallVector<SmallVector<int64_t>> targetShapes;
- SmallVector<bool> collapsed;
- for (OpOperand &opOperand : genericOp->getOpOperands()) {
- auto replacementInfo = replaceUnitExtents(genericOp, &opOperand, context);
- if (replacementInfo) {
- reassociations.push_back(replacementInfo->reassociation);
- newIndexingMaps.push_back(replacementInfo->indexMap);
- targetShapes.push_back(replacementInfo->targetShape);
- collapsed.push_back(true);
- } else {
- // If replaceUnitExtents cannot handle this case (or no unit dim was
- // removed), maintain the same type, indexing map, and create a set of
- // mappings representing an identity matrix.
- newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(&opOperand));
- reassociations.emplace_back();
- targetShapes.emplace_back();
- collapsed.push_back(false);
- }
+ // Abort if the indexing maps of the result operation are not invertible
+ // (i.e. not legal) or if no dimension was reduced.
+ if (newIndexingMaps == indexingMaps ||
+ !inversePermutation(concatAffineMaps(newIndexingMaps)))
+ return failure();
+
+ Location loc = genericOp.getLoc();
+ // 4. For each of the operands, collapse the operand to convert
+ // from original shape to shape in the modified operation if needed,
+ // either through use of reshapes or rank-reducing slices as
+ // specified in `options`.
+ SmallVector<Value> newOperands;
+ for (OpOperand &opOperand : genericOp->getOpOperands()) {
+ int64_t idx = opOperand.getOperandNumber();
+ if (!collapsed[idx]) {
+ newOperands.push_back(opOperand.get());
+ continue;
}
+ newOperands.push_back(collapseValue(rewriter, loc, opOperand.get(),
+ targetShapes[idx], reassociations[idx],
+ options.rankReductionStrategy));
+ }
- // Abort if the indexing maps of the result operation are not invertible
- // (i.e. not legal) or if no dimension was reduced.
- if (!llvm::any_of(collapsed, [](bool c) { return c; }) ||
- !inversePermutation(concatAffineMaps(newIndexingMaps)))
- return failure();
-
- // Insert rank reductions.
- SmallVector<Value> newOperands;
- for (OpOperand &opOperand : genericOp->getOpOperands()) {
- int64_t idx = opOperand.getOperandNumber();
- if (!collapsed[idx]) {
- newOperands.push_back(opOperand.get());
- continue;
- }
- newOperands.push_back(collapseValue(opOperand.get(), targetShapes[idx],
- reassociations[idx], loc, rewriter));
+ // 5. Create the `linalg.generic` operation with the new operands,
+ // indexing maps, iterator types and result types.
+ ArrayRef<Value> newInputs =
+ ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
+ ArrayRef<Value> newOutputs =
+ ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
+ SmallVector<Type> resultTypes;
+ resultTypes.reserve(genericOp.getNumResults());
+ for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
+ resultTypes.push_back(newOutputs[i].getType());
+ GenericOp replacementOp =
+ rewriter.create<GenericOp>(loc, resultTypes, newInputs, newOutputs,
+ newIndexingMaps, newIteratorTypes);
+ rewriter.inlineRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
+ replacementOp.getRegion().begin());
+ // 5a. Replace `linalg.index` operations that refer to the dropped unit
+ // dimensions.
+ replaceUnitDimIndexOps(replacementOp, unitDims, rewriter);
+
+ // 6. If any result type changes, insert a reshape/slice to convert from the
+ // original
+ // type to the new type.
+ SmallVector<Value> resultReplacements;
+ for (auto [index, result] : llvm::enumerate(replacementOp.getResults())) {
+ unsigned opOperandIndex = index + replacementOp.getNumDpsInputs();
+ Value origDest = genericOp.getDpsInitOperand(index)->get();
+ if (!collapsed[opOperandIndex]) {
+ resultReplacements.push_back(result);
+ continue;
}
+ resultReplacements.push_back(expandValue(rewriter, loc, result, origDest,
+ reassociations[opOperandIndex],
+ options.rankReductionStrategy));
+ }
- // If any result type changes, insert a reshape to convert from the original
- // type to the new type.
- ArrayRef<Value> newInputs =
- ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
- ArrayRef<Value> newOutputs =
- ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
- SmallVector<Type> resultTypes;
- resultTypes.reserve(genericOp.getNumResults());
- for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
- resultTypes.push_back(newOutputs[i].getType());
- GenericOp replacementOp = rewriter.create<GenericOp>(
- loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
- genericOp.getIteratorTypesArray());
- rewriter.inlineRegionBefore(genericOp.getRegion(),
- replacementOp.getRegion(),
- replacementOp.getRegion().begin());
-
- // If any result tensor has a modified shape, then add reshape to recover
- // the original shape.
- SmallVector<Value> resultReplacements;
- for (const auto &result : llvm::enumerate(replacementOp.getResults())) {
- unsigned index = result.index() + replacementOp.getNumDpsInputs();
- Value origOutput = oldOutputs[result.index()];
- if (!collapsed[result.index() + genericOp.getNumDpsInputs()]) {
- resultReplacements.push_back(result.value());
- continue;
- }
- resultReplacements.push_back(expandValue(
- result.value(), origOutput, reassociations[index], loc, rewriter));
- }
+ rewriter.replaceOp(genericOp, resultReplacements);
+ return success();
+}
- rewriter.replaceOp(genericOp, resultReplacements);
- return success();
+namespace {
+struct DropUnitDims : public OpRewritePattern<GenericOp> {
+ DropUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
+ PatternBenefit benefit = 1)
+ : OpRewritePattern(context, benefit), options(std::move(options)) {}
+
+ LogicalResult matchAndRewrite(GenericOp genericOp,
+ PatternRewriter &rewriter) const override {
+ return dropUnitDims(rewriter, genericOp, options);
}
private:
- RankReductionStrategy rankReductionStrategy;
+ ControlDropUnitDims options;
};
} // namespace
@@ -641,8 +613,8 @@ struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
tensor::CollapseShapeOp reshapedSource;
{
OpBuilder::InsertionGuard g(rewriter);
-      // The only difference between InsertSliceOp and ParallelInsertSliceOp is
-      // the insertion point is just before the ParallelCombiningOp in the
+      // The only difference between InsertSliceOp and ParallelInsertSliceOp
+      // is the insertion point is just before the ParallelCombiningOp in the
// parallel case.
if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
rewriter.setInsertionPoint(insertSliceOp->getParentOp());
@@ -660,13 +632,13 @@ struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
-void mlir::linalg::populateFoldUnitExtentDimsViaReshapesPatterns(
- RewritePatternSet &patterns) {
+static void
+populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns,
+ ControlDropUnitDims &options) {
auto *context = patterns.getContext();
- patterns.add<ReplaceUnitExtents>(context,
- RankReductionStrategy::ReassociativeReshape);
+ patterns.add<DropUnitDims>(context, options);
// TODO: Patterns unrelated to unit dim folding should be factored out.
- patterns.add<FoldUnitDimLoops, RankReducedExtractSliceOp,
+ patterns.add<RankReducedExtractSliceOp,
RankReducedInsertSliceOp<tensor::InsertSliceOp>,
RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
context);
@@ -679,12 +651,13 @@ void mlir::linalg::populateFoldUnitExtentDimsViaReshapesPatterns(
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}
-void mlir::linalg::populateFoldUnitExtentDimsViaSlicesPatterns(
- RewritePatternSet &patterns) {
+static void
+populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns,
+ ControlDropUnitDims &options) {
auto *context = patterns.getContext();
- patterns.add<ReplaceUnitExtents>(context,
- RankReductionStrategy::ExtractInsertSlice);
- patterns.add<FoldUnitDimLoops>(context);
+ options.rankReductionStrategy =
+ ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
+ patterns.add<DropUnitDims>(context, options);
// TODO: Patterns unrelated to unit dim folding should be factored out.
linalg::FillOp::getCanonicalizationPatterns(patterns, context);
tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
@@ -693,6 +666,18 @@ void mlir::linalg::populateFoldUnitExtentDimsViaSlicesPatterns(
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}
+void mlir::linalg::populateFoldUnitExtentDimsPatterns(
+ RewritePatternSet &patterns, linalg::ControlDropUnitDims &options) {
+ if (options.rankReductionStrategy ==
+ linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
+ populateFoldUnitExtentDimsViaSlicesPatterns(patterns, options);
+ } else if (options.rankReductionStrategy ==
+ linalg::ControlDropUnitDims::RankReductionStrategy::
+ ReassociativeReshape) {
+ populateFoldUnitExtentDimsViaReshapesPatterns(patterns, options);
+ }
+}
+
void mlir::linalg::populateMoveInitOperandsToInputPattern(
RewritePatternSet &patterns) {
patterns.add<MoveInitOperandsToInput>(patterns.getContext());
@@ -706,15 +691,13 @@ struct LinalgFoldUnitExtentDimsPass
Operation *op = getOperation();
MLIRContext *context = op->getContext();
RewritePatternSet patterns(context);
- if (foldOneTripLoopsOnly) {
- patterns.add<FoldUnitDimLoops, MoveInitOperandsToInput>(context);
- } else if (useRankReducingSlices) {
- populateFoldUnitExtentDimsViaSlicesPatterns(patterns);
- populateMoveInitOperandsToInputPattern(patterns);
- } else {
- populateFoldUnitExtentDimsViaReshapesPatterns(patterns);
- populateMoveInitOperandsToInputPattern(patterns);
+ ControlDropUnitDims options;
+ if (useRankReducingSlices) {
+ options.rankReductionStrategy = linalg::ControlDropUnitDims::
+ RankReductionStrategy::ExtractInsertSlice;
}
+ linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
+ populateMoveInitOperandsToInputPattern(patterns);
(void)applyPatternsAndFoldGreedily(op, std::move(patterns));
}
};
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index b48a1d6e0cf0b4..88659f8628ae70 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -59,24 +59,24 @@ func.func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>, %arg1 : f32, %shape: t
library_call = "some_external_func"
}
-func.func @drop_one_trip_loops_all_ones(%arg0 : tensor<1x1x1xf32>, %arg1 : f32, %shape: tensor<?x1x?x1x?xf32>) -> tensor<?x1x?x1x?xf32> {
+func.func @drop_one_trip_loops_all_ones(%arg0 : tensor<1x1x1xf32>, %arg1 : f32, %shape: tensor<1x1x?x1x1xf32>) -> tensor<1x1x?x1x1xf32> {
%0 = linalg.generic #trait
ins(%arg0, %arg1 : tensor<1x1x1xf32>, f32)
- outs(%shape : tensor<?x1x?x1x?xf32>) {
+ outs(%shape : tensor<1x1x?x1x1xf32>) {
^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32) :
linalg.yield %arg3 : f32
- } -> tensor<?x1x?x1x?xf32>
- return %0 : tensor<?x1x?x1x?xf32>
+ } -> tensor<1x1x?x1x1xf32>
+ return %0 : tensor<1x1x?x1x1xf32>
}
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> ()>
-// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (0, d0, 0)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func @drop_one_trip_loops_all_ones
// CHECK: tensor.collapse_shape %{{.*}} []
-// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]]
+// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4]]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: iterator_types = ["parallel"]
-// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]]
+// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4]]
// -----
@@ -922,4 +922,3 @@ func.func @drop_all_loops(%arg0 : memref<1x1xf32, 3>) -> memref<1x1xf32, 3>
// CHECK-SLICES-LABEL: func @drop_all_loops
// CHECK-SLICES: memref.subview %{{.*}}[0, 0] [1, 1] [1, 1] : memref<1x1xf32, 3> to memref<f32, strided<[]>, 3>
// CHECK-SLICES: linalg.generic{{.*}}memref<f32, strided<[]>, 3>
-
diff --git a/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir b/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir
deleted file mode 100644
index 2f265a72fd7bf7..00000000000000
--- a/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir
+++ /dev/null
@@ -1,110 +0,0 @@
-// RUN: mlir-opt %s -split-input-file -pass-pipeline="builtin.module(func.func(linalg-fold-unit-extent-dims{fold-one-trip-loops-only}))" | FileCheck %s
-
-#accesses = [
- affine_map<(i, j, k, l, m) -> (i, k, m)>,
- affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
-]
-
-#trait = {
- iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
- indexing_maps = #accesses,
- library_call = "some_external_func"
-}
-
-func.func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>, %shape: tensor<?x1x?x1x?xf32>) -> tensor<?x1x?x1x?xf32>
-{
- %0 = linalg.generic #trait
- ins(%arg0 : tensor<?x1x?xf32>)
- outs(%shape : tensor<?x1x?x1x?xf32>) {
- ^bb0(%arg1 : f32, %arg2 : f32) :
- linalg.yield %arg1 : f32
- } -> tensor<?x1x?x1x?xf32>
- return %0 : tensor<?x1x?x1x?xf32>
-}
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d1, 0, d2)>
-// CHECK-LABEL: func @drop_one_trip_loops
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
-
-// -----
-
-#map0 = affine_map<(i, j) -> (i, j)>
-#access = [#map0, #map0]
-#trait = {
- iterator_types = ["parallel", "parallel"],
- indexing_maps = #access,
- library_call = "some_external_func"
-}
-
-func.func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32>
-{
- %0 = linalg.generic #trait
- ins(%arg0 : tensor<1x1xf32>)
- outs(%arg0 : tensor<1x1xf32>) {
- ^bb0(%arg1: f32, %arg2: f32) :
- linalg.yield %arg1 : f32
- } -> tensor<1x1xf32>
- return %0 : tensor<1x1xf32>
-}
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<() -> (0, 0)>
-// CHECK-LABEL: func @drop_all_loops
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
-// CHECK-SAME: iterator_types = []
-
-// -----
-
-#map0 = affine_map<(i, j) -> (i, j)>
-#access = [#map0, #map0]
-#trait = {
- iterator_types = ["parallel", "parallel"],
- indexing_maps = #access,
- library_call = "some_external_func"
-}
-
-func.func @drop_all_loops(%arg0 : memref<1x1xf32>, %arg1 : memref<1x1xf32>)
-{
- linalg.generic #trait
- ins(%arg0 : memref<1x1xf32>)
- outs(%arg1 : memref<1x1xf32>) {
- ^bb0(%arg2: f32, %arg3 : f32) :
- linalg.yield %arg2 : f32
- }
- return
-}
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<() -> (0, 0)>
-// CHECK-LABEL: func @drop_all_loops
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
-// CHECK-SAME: iterator_types = []
-
-// -----
-
-#accesses = [
- affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d1)>
-]
-
-#trait = {
- indexing_maps = #accesses,
- iterator_types = ["parallel", "parallel"],
- library_call = "some_external_fn"
-}
-
-func.func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>, %shape: tensor<5xf32>) -> tensor<5xf32> {
- %0 = linalg.generic #trait
- ins(%arg0 : tensor<1x5xf32>)
- outs(%shape : tensor<5xf32>) {
- ^bb0(%arg2: f32, %arg3: f32):
- linalg.yield %arg2 : f32
- } -> tensor<5xf32>
- return %0 : tensor<5xf32>
-}
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (0, d0)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>
-// CHECK-LABEL: func @leading_dim_1_canonicalization
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-// CHECK-SAME: iterator_types = ["parallel"]
diff --git a/mlir/test/Dialect/Linalg/test-drop-unit-dims.mlir b/mlir/test/Dialect/Linalg/test-drop-unit-dims.mlir
new file mode 100644
index 00000000000000..35eeffc1f99532
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/test-drop-unit-dims.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt -test-linalg-drop-unit-dims --split-input-file %s | FileCheck %s
+
+// Drop only the outermost unit dimension (controlled using a control function)
+func.func @drop_outermost_unit_dims(%arg0: tensor<1x1x42xf32>) -> tensor<1x1x42xf32> {
+ %0 = tensor.empty() : tensor<1x1x42xf32>
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%arg0 : tensor<1x1x42xf32>) outs(%0 : tensor<1x1x42xf32>) {
+ ^bb0(%b0: f32, %b1 : f32):
+ %2 = arith.addf %b0, %b1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<1x1x42xf32>
+ return %1 : tensor<1x1x42xf32>
+}
+// CHECK-LABEL: func @drop_outermost_unit_dims
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x42xf32>
+// CHECK: %[[OUTS:.+]] = tensor.empty()
+// CHECK: %[[ARG0_RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1], [2]{{\]}}
+// CHECK: %[[OUTS_RESHAPE:.+]] = tensor.collapse_shape %[[OUTS]] {{\[}}[0, 1], [2]{{\]}}
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0_RESHAPE]] :
+// CHECK-SAME: outs(%[[OUTS_RESHAPE]] :
+// CHECK: %[[EXPAND_SHAPE:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1], [2]{{\]}}
+// CHECK: return %[[EXPAND_SHAPE]]
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index 03aefc8d7117e0..b28f2b3564662a 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -2,6 +2,7 @@
add_mlir_library(MLIRLinalgTestPasses
TestDataLayoutPropagation.cpp
TestLinalgDecomposeOps.cpp
+ TestLinalgDropUnitDims.cpp
TestLinalgElementwiseFusion.cpp
TestLinalgFusionTransforms.cpp
TestLinalgTransforms.cpp
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp
new file mode 100644
index 00000000000000..a3a6a49d64b003
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp
@@ -0,0 +1,73 @@
+//===- TestLinalgDropUnitDims.cpp - Test Linalg drop unit dims -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing the transformation to drop unit
+// extent dimensions from `linalg.generic` operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+LogicalResult dropOutermostUnitDims(RewriterBase &rewriter,
+ linalg::GenericOp genericOp) {
+ linalg::ControlDropUnitDims options;
+ options.controlFn = [](Operation *op) { return SmallVector<unsigned>{0}; };
+ return linalg::dropUnitDims(rewriter, genericOp, options);
+}
+
+struct TestLinalgDropUnitDims
+ : public PassWrapper<TestLinalgDropUnitDims, OperationPass<func::FuncOp>> {
+
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgDropUnitDims)
+
+ TestLinalgDropUnitDims() = default;
+ TestLinalgDropUnitDims(const TestLinalgDropUnitDims &pass)
+ : PassWrapper(pass) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<linalg::LinalgDialect>();
+ }
+
+ StringRef getArgument() const final { return "test-linalg-drop-unit-dims"; }
+
+  StringRef getDescription() const final {
+ return "Test transformation to drop unit-extent dims from Linalg "
+ "operations";
+ }
+
+ void runOnOperation() override {
+ MLIRContext *context = &this->getContext();
+ func::FuncOp funcOp = this->getOperation();
+ IRRewriter rewriter(context);
+ SmallVector<linalg::GenericOp> genericOps;
+ funcOp.walk(
+ [&](linalg::GenericOp genericOp) { genericOps.push_back(genericOp); });
+
+ for (auto genericOp : genericOps) {
+ rewriter.setInsertionPoint(genericOp);
+ dropOutermostUnitDims(rewriter, genericOp);
+ }
+ }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestLinalgDropUnitDims() {
+ PassRegistration<TestLinalgDropUnitDims>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index e4ba8ab36393bc..08bb6445fb0bc8 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -100,6 +100,7 @@ void registerTestGenericIRVisitorsInterruptPass();
void registerTestInterfaces();
void registerTestLastModifiedPass();
void registerTestLinalgDecomposeOps();
+void registerTestLinalgDropUnitDims();
void registerTestLinalgElementwiseFusion();
void registerTestLinalgGreedyFusion();
void registerTestLinalgTransforms();
@@ -222,6 +223,7 @@ void registerTestPasses() {
mlir::test::registerTestInterfaces();
mlir::test::registerTestLastModifiedPass();
mlir::test::registerTestLinalgDecomposeOps();
+ mlir::test::registerTestLinalgDropUnitDims();
mlir::test::registerTestLinalgElementwiseFusion();
mlir::test::registerTestLinalgGreedyFusion();
mlir::test::registerTestLinalgTransforms();