[Mlir-commits] [mlir] e07149c - [mlir][linalg] Add option to generate rank-reducing slices in DropUnitDims
Matthias Springer
llvmlistbot at llvm.org
Wed Dec 14 05:10:15 PST 2022
Author: Matthias Springer
Date: 2022-12-14T14:10:04+01:00
New Revision: e07149c91f4bb43c1413bc1fbe19dc6eff2fcde6
URL: https://github.com/llvm/llvm-project/commit/e07149c91f4bb43c1413bc1fbe19dc6eff2fcde6
DIFF: https://github.com/llvm/llvm-project/commit/e07149c91f4bb43c1413bc1fbe19dc6eff2fcde6.diff
LOG: [mlir][linalg] Add option to generate rank-reducing slices in DropUnitDims
This change extends the `ReplaceUnitExtents` pattern so that users can choose between two strategies for generating rank reductions (illustrated below):
* CollapseShapeOp / ExpandShapeOp (already implemented; the code was cleaned up; this remains the default strategy)
* rank-reducing ExtractSliceOp / InsertSliceOp
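For illustration only (not part of the commit message), a minimal sketch of what the two strategies produce for a `tensor<?x1x?xf32>` operand whose middle unit dimension is dropped; the SSA names and the dynamic sizes %d0 and %d2 are hypothetical:

  // Default strategy: reassociative reshape.
  %collapsed = tensor.collapse_shape %operand [[0, 1], [2]]
      : tensor<?x1x?xf32> into tensor<?x?xf32>

  // With use-rank-reducing-slices: rank-reducing slice.
  %sliced = tensor.extract_slice %operand[0, 0, 0] [%d0, 1, %d2] [1, 1, 1]
      : tensor<?x1x?xf32> to tensor<?x?xf32>

On the result side, the default strategy expands back to the original type with tensor.expand_shape, while the slice-based strategy inserts back with tensor.insert_slice. The new behavior is enabled via -linalg-fold-unit-extent-dims="use-rank-reducing-slices" (see the updated RUN line in drop-unit-extent-dims.mlir).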
Also add helper functions to the memref dialect that already exist in the tensor dialect: `getMixedSizes`, `createCanonicalRankReducingSubViewOp`, `rankReduceIfNeeded`.
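As an aside (not part of the patch), a minimal sketch of how the new memref helpers might be used from a rewrite pattern; `b`, `loc`, and `src` are assumed to be an OpBuilder, a Location, and a Value of type `memref<4x1x8xf32>`:

  // Sizes of `src` as OpFoldResults (attributes for static dims,
  // memref.dim ops for dynamic ones).
  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(b, loc, src);

  // Drop the unit dimension via a rank-reducing memref.subview. Returns
  // `src` unchanged if the shape already matches, or failure if the shape
  // cannot be rank-reduced to the desired shape.
  FailureOr<Value> reduced =
      memref::SubViewOp::rankReduceIfNeeded(b, loc, src, /*desiredShape=*/{4, 8});
  if (succeeded(reduced)) {
    // *reduced is the rank-reduced subview of `src` (or `src` itself if no
    // reduction was needed).
  }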
`ReassociationIndices` is now used instead of `ReassociationExprs`; this matches many other places in the code base and makes the code easier to read. A new test case is also added (one that already passed before this change).
Differential Revision: https://reviews.llvm.org/D139947
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 550feef376db5..05397c6775d35 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -31,7 +31,10 @@ def LinalgFoldUnitExtentDims : Pass<"linalg-fold-unit-extent-dims", ""> {
Option<"foldOneTripLoopsOnly", "fold-one-trip-loops-only", "bool",
/*default=*/"false",
"Only folds the one-trip loops from Linalg ops on tensors "
- "(for testing purposes only)">
+ "(for testing purposes only)">,
+ Option<"useRankReducingSlices", "use-rank-reducing-slices", "bool",
+ /*default=*/"false",
+ "Generate rank-reducing slices instead of reassociative reshapes">
];
let dependentDialects = [
"linalg::LinalgDialect", "AffineDialect", "memref::MemRefDialect"
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 7b3e0726effc0..13a7e6f3f2aec 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -129,8 +129,12 @@ void populateFuseTensorPadWithProducerLinalgOpPatterns(
void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);
/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
-/// tensors.
-void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
+/// tensors via reassociative reshape ops.
+void populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns);
+
+/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
+/// tensors via rank-reducing slices.
+void populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns);
/// Patterns that are used to inline constant operands into linalg generic ops.
void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index 9932f36e5f1bf..bcf86b07db3b4 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -54,6 +54,16 @@ Type getTensorTypeFromMemRefType(Type type);
/// single deallocate if it exists or nullptr.
Optional<Operation *> findDealloc(Value allocValue);
+/// Return the dimensions of the given memref value.
+SmallVector<OpFoldResult> getMixedSizes(OpBuilder &builder, Location loc,
+ Value value);
+
+/// Create a rank-reducing SubViewOp @[0 .. 0] with strides [1 .. 1] and
+/// appropriate sizes (i.e. `memref.getSizes()`) to reduce the rank of `memref`
+/// to that of `targetShape`.
+Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc,
+ Value memref,
+ ArrayRef<int64_t> targetShape);
} // namespace memref
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 4a567b40e2e5e..5233badf0bcc1 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1954,6 +1954,15 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
/// Return the dimensions of the source type that are dropped when
/// the result is rank-reduced.
llvm::SmallBitVector getDroppedDims();
+
+ /// Given a `value`, asserted to be of MemRefType, build a SubViewOp that
+ /// results in a rank reduction to the desired memref shape and return the
+ /// new value created.
+ /// If the shape of `value` is already the `desiredShape`, just return
+ /// `value`.
+ /// If the shape of `value` cannot be rank-reduced to `desiredShape`, fail.
+ static FailureOr<Value> rankReduceIfNeeded(
+ OpBuilder &b, Location loc, Value value, ArrayRef<int64_t> desiredShape);
}];
let hasCanonicalizer = 1;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index c0bf57cb705a6..ab288c0bb5f1a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -42,6 +42,10 @@ namespace mlir {
using namespace mlir;
using namespace mlir::linalg;
+namespace {
+enum class RankReductionStrategy { ReassociativeReshape, ExtractInsertSlice };
+} // namespace
+
/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting. For example,
///
@@ -349,9 +353,9 @@ struct AddInitOperandsToInput : public OpRewritePattern<GenericOp> {
};
struct UnitExtentReplacementInfo {
- Type type;
AffineMap indexMap;
- ArrayAttr reassociation;
+ SmallVector<ReassociationIndices> reassociation;
+ SmallVector<int64_t> targetShape;
};
} // namespace
@@ -371,8 +375,6 @@ replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
ArrayRef<AffineExpr> exprs = indexingMap.getResults();
- SmallVector<AffineExpr> reassociations;
- SmallVector<Attribute> reassociationMaps;
SmallVector<AffineExpr> newIndexExprs;
SmallVector<int64_t> newShape;
@@ -391,99 +393,110 @@ replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
}
int64_t dim = 0;
+ SmallVector<ReassociationIndices> reassociation;
+ ReassociationIndices reassociationGroup;
// Fold dimensions that are unit-extent at the beginning of the tensor.
while (dim < origRank && isUnitExtent(dim))
- reassociations.push_back(getAffineDimExpr(dim++, context));
+ reassociationGroup.push_back(dim++);
while (dim < origRank) {
- reassociations.push_back(getAffineDimExpr(dim, context));
+ assert(!isUnitExtent(dim) && "expected non unit-extent");
+ reassociationGroup.push_back(dim);
newIndexExprs.push_back(exprs[dim]);
newShape.push_back(shape[dim]);
- // Fold all following dimensions that are unit-extent.
- while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
- ++dim;
- reassociations.push_back(getAffineDimExpr(dim, context));
- }
- reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
- origRank, /*symbolCount = */ 0, reassociations, context)));
- reassociations.clear();
++dim;
+ // Fold all following dimensions that are unit-extent.
+ while (dim < origRank && isUnitExtent(dim))
+ reassociationGroup.push_back(dim++);
+ reassociation.push_back(reassociationGroup);
+ reassociationGroup.clear();
}
- // Compute the tensor or scalar replacement type.
- Type elementType = getElementTypeOrSelf(opOperand->get());
- Type replacementType;
- if (elementType == opOperand->get().getType()) {
- replacementType = elementType;
- } else if (actualType.isa<RankedTensorType>()) {
- replacementType = RankedTensorType::get(newShape, elementType);
- } else {
- auto memrefType = actualType.cast<MemRefType>();
- replacementType = MemRefType::get(newShape, elementType, {},
- memrefType.getMemorySpaceAsInt());
- }
- UnitExtentReplacementInfo info = {replacementType,
- AffineMap::get(indexingMap.getNumDims(),
- indexingMap.getNumSymbols(),
- newIndexExprs, context),
- ArrayAttr::get(context, reassociationMaps)};
+ // Return if the rank was not reduced.
+ if (origRank == static_cast<int64_t>(newShape.size()))
+ return std::nullopt;
+
+ UnitExtentReplacementInfo info = {
+ /*indexMap=*/AffineMap::get(indexingMap.getNumDims(),
+ indexingMap.getNumSymbols(), newIndexExprs,
+ context),
+ /*reassociation=*/reassociation, /*targetShape=*/newShape};
return info;
}
namespace {
-SmallVector<ReassociationExprs, 2>
-convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) {
- SmallVector<ReassociationExprs, 2> reassociationExprs;
- for (auto attr : affineMapArrayAttr)
- reassociationExprs.push_back(
- llvm::to_vector<4>(attr.cast<AffineMapAttr>().getValue().getResults()));
- return reassociationExprs;
-}
-
/// Pattern to replace tensor/buffer operands/results that are unit extents.
struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
- using OpRewritePattern<GenericOp>::OpRewritePattern;
-
- // Return the original value if the type is unchanged, or reshape it. Return a
- // nullptr if this is an unsupported type.
- Value maybeExpand(Value result, Type origResultType,
- ArrayAttr reassociationMap, Location loc,
+ ReplaceUnitExtents(MLIRContext *ctx,
+ RankReductionStrategy rankReductionStrategy)
+ : OpRewritePattern<GenericOp>(ctx),
+ rankReductionStrategy(rankReductionStrategy) {}
+
+ // Expand the given value.
+ Value expandValue(Value result, Value origOutput,
+ ArrayRef<ReassociationIndices> reassociation, Location loc,
PatternRewriter &rewriter) const {
- if (origResultType == result.getType())
- return result;
- if (origResultType.isa<RankedTensorType>()) {
- return rewriter.create<tensor::ExpandShapeOp>(
- loc, origResultType, result,
- convertAffineMapArrayToExprs(reassociationMap));
- }
- if (origResultType.isa<MemRefType>()) {
- return rewriter.create<memref::ExpandShapeOp>(
- loc, origResultType, result,
- convertAffineMapArrayToExprs(reassociationMap));
+ // There are no results for memref outputs.
+ auto origResultType = origOutput.getType().cast<RankedTensorType>();
+ if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
+ unsigned rank = origResultType.getRank();
+ SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+ SmallVector<OpFoldResult> sizes =
+ tensor::getMixedSizes(rewriter, loc, origOutput);
+ SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+ return rewriter.createOrFold<tensor::InsertSliceOp>(
+ loc, result, origOutput, offsets, sizes, strides);
}
- return nullptr;
- };
- // Return the original value if the type is unchanged, or reshape it. Return a
- // nullptr if this is an unsupported type.
- Value maybeCollapse(Value operand, Type newInputOutputType,
- ArrayAttr reassociationMap, Location loc,
- PatternRewriter &rewriter) const {
- auto operandType = operand.getType();
- if (operandType == newInputOutputType)
- return operand;
- if (operandType.isa<MemRefType>()) {
- return rewriter.create<memref::CollapseShapeOp>(
- loc, newInputOutputType, operand,
- convertAffineMapArrayToExprs(reassociationMap));
+ assert(rankReductionStrategy ==
+ RankReductionStrategy::ReassociativeReshape &&
+ "unknown rank reduction strategy");
+ return rewriter.create<tensor::ExpandShapeOp>(loc, origResultType, result,
+ reassociation);
+ }
+
+ // Collapse the given value.
+ Value collapseValue(Value operand, ArrayRef<int64_t> targetShape,
+ ArrayRef<ReassociationIndices> reassociation,
+ Location loc, PatternRewriter &rewriter) const {
+ if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
+ if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
+ FailureOr<Value> rankReducingExtract =
+ memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
+ targetShape);
+ assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
+ return *rankReducingExtract;
+ }
+
+ assert(rankReductionStrategy ==
+ RankReductionStrategy::ReassociativeReshape &&
+ "unknown rank reduction strategy");
+ MemRefLayoutAttrInterface layout;
+ auto targetType =
+ MemRefType::get(targetShape, memrefType.getElementType(), layout,
+ memrefType.getMemorySpace());
+ return rewriter.create<memref::CollapseShapeOp>(loc, targetType, operand,
+ reassociation);
}
- if (operandType.isa<RankedTensorType>()) {
- return rewriter.create<tensor::CollapseShapeOp>(
- loc, newInputOutputType, operand,
- convertAffineMapArrayToExprs(reassociationMap));
+ if (auto tensorType = operand.getType().dyn_cast<RankedTensorType>()) {
+ if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
+ FailureOr<Value> rankReducingExtract =
+ tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
+ targetShape);
+ assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
+ return *rankReducingExtract;
+ }
+
+ assert(rankReductionStrategy ==
+ RankReductionStrategy::ReassociativeReshape &&
+ "unknown rank reduction strategy");
+ auto targetType =
+ RankedTensorType::get(targetShape, tensorType.getElementType());
+ return rewriter.create<tensor::CollapseShapeOp>(loc, targetType, operand,
+ reassociation);
}
- return nullptr;
- };
+ llvm_unreachable("unsupported operand type");
+ }
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
@@ -495,71 +508,59 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
return failure();
MLIRContext *context = rewriter.getContext();
Location loc = genericOp.getLoc();
+ SmallVector<Value> oldOutputs(genericOp.getOutputs().begin(),
+ genericOp.getOutputs().end());
SmallVector<AffineMap> newIndexingMaps;
- SmallVector<ArrayAttr> reassociationMaps;
- SmallVector<Type> newInputOutputTypes;
- bool doCanonicalization = false;
+ SmallVector<SmallVector<ReassociationIndices>> reassociations;
+ SmallVector<SmallVector<int64_t>> targetShapes;
+ SmallVector<bool> collapsed;
for (OpOperand &opOperand : genericOp->getOpOperands()) {
auto replacementInfo = replaceUnitExtents(genericOp, &opOperand, context);
if (replacementInfo) {
- reassociationMaps.push_back(replacementInfo->reassociation);
+ reassociations.push_back(replacementInfo->reassociation);
newIndexingMaps.push_back(replacementInfo->indexMap);
- newInputOutputTypes.push_back(replacementInfo->type);
- doCanonicalization |=
- replacementInfo->type != opOperand.get().getType();
+ targetShapes.push_back(replacementInfo->targetShape);
+ collapsed.push_back(true);
} else {
- // If replaceUnitExtents cannot handle this case, maintain the same
- // type, indexing map, and create a set of mappings representing an
- // identity matrix.
- newInputOutputTypes.push_back(opOperand.get().getType());
+ // If replaceUnitExtents cannot handle this case (or no unit dim was
+ // removed), maintain the same type, indexing map, and create a set of
+ // mappings representing an identity matrix.
newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(&opOperand));
- int64_t origRank = genericOp.getRank(&opOperand);
- auto maps = llvm::to_vector<8>(llvm::map_range(
- llvm::seq<int64_t>(0, origRank), [&](int64_t dim) -> Attribute {
- return AffineMapAttr::get(
- AffineMap::get(origRank, /*symbolCount = */ 0,
- getAffineDimExpr(dim, context), context));
- }));
- reassociationMaps.push_back(ArrayAttr::get(context, maps));
+ reassociations.emplace_back();
+ targetShapes.emplace_back();
+ collapsed.push_back(false);
}
}
- // If the indexing maps of the result operation are not invertible (i.e. not
- // legal), abort.
- if (!doCanonicalization ||
+ // Abort if the indexing maps of the result operation are not invertible
+ // (i.e. not legal) or if no dimension was reduced.
+ if (!llvm::any_of(collapsed, [](bool c) { return c; }) ||
!inversePermutation(concatAffineMaps(newIndexingMaps)))
return failure();
- // If any operand type change, insert a reshape to convert from the original
- // type to the new type.
- // TODO: get rid of flattenedIdx which assumes operand order and contiguity.
- unsigned flattenedIdx = 0;
- auto insertReshapes = [&](ValueRange values) {
- SmallVector<Value, 4> res;
- res.reserve(values.size());
- for (auto operand : values) {
- auto reshapedValue =
- maybeCollapse(operand, newInputOutputTypes[flattenedIdx],
- reassociationMaps[flattenedIdx], loc, rewriter);
- assert(reshapedValue &&
- "expected ranked MemRef or Tensor operand type");
- res.push_back(reshapedValue);
- ++flattenedIdx;
+ // Insert rank reductions.
+ SmallVector<Value> newOperands;
+ for (OpOperand &opOperand : genericOp->getOpOperands()) {
+ int64_t idx = opOperand.getOperandNumber();
+ if (!collapsed[idx]) {
+ newOperands.push_back(opOperand.get());
+ continue;
}
- return res;
- };
-
- SmallVector<Value, 4> newInputs = insertReshapes(genericOp.getInputs());
- SmallVector<Value, 4> newOutputs = insertReshapes(genericOp.getOutputs());
+ newOperands.push_back(collapseValue(opOperand.get(), targetShapes[idx],
+ reassociations[idx], loc, rewriter));
+ }
// If any result type changes, insert a reshape to convert from the original
// type to the new type.
- SmallVector<Type, 4> resultTypes;
+ ArrayRef<Value> newInputs =
+ ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
+ ArrayRef<Value> newOutputs =
+ ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
+ SmallVector<Type> resultTypes;
resultTypes.reserve(genericOp.getNumResults());
for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
- resultTypes.push_back(
- newInputOutputTypes[i + genericOp.getNumDpsInputs()]);
+ resultTypes.push_back(newOutputs[i].getType());
GenericOp replacementOp = rewriter.create<GenericOp>(
loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
genericOp.getIteratorTypesArray());
@@ -569,20 +570,24 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
// If any result tensor has a modified shape, then add reshape to recover
// the original shape.
- SmallVector<Value, 4> resultReplacements;
+ SmallVector<Value> resultReplacements;
for (const auto &result : llvm::enumerate(replacementOp.getResults())) {
unsigned index = result.index() + replacementOp.getNumDpsInputs();
- auto origResultType = genericOp.getResult(result.index()).getType();
-
- auto newResult = maybeExpand(result.value(), origResultType,
- reassociationMaps[index], loc, rewriter);
- assert(newResult &&
- "unexpected output type other than ranked MemRef or Tensor");
- resultReplacements.push_back(newResult);
+ Value origOutput = oldOutputs[result.index()];
+ if (!collapsed[result.index() + genericOp.getNumDpsInputs()]) {
+ resultReplacements.push_back(result.value());
+ continue;
+ }
+ resultReplacements.push_back(expandValue(
+ result.value(), origOutput, reassociations[index], loc, rewriter));
}
+
rewriter.replaceOp(genericOp, resultReplacements);
return success();
}
+
+private:
+ RankReductionStrategy rankReductionStrategy;
};
} // namespace
@@ -656,14 +661,16 @@ struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
-void mlir::linalg::populateFoldUnitExtentDimsPatterns(
+void mlir::linalg::populateFoldUnitExtentDimsViaReshapesPatterns(
RewritePatternSet &patterns) {
auto *context = patterns.getContext();
- patterns.add<FoldUnitDimLoops, AddInitOperandsToInput, ReplaceUnitExtents,
- RankReducedExtractSliceOp,
- RankReducedInsertSliceOp<tensor::InsertSliceOp>,
- RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
- context);
+ patterns.add<ReplaceUnitExtents>(context,
+ RankReductionStrategy::ReassociativeReshape);
+ // TODO: Patterns unrelated to unit dim folding should be factored out.
+ patterns
+ .add<FoldUnitDimLoops, AddInitOperandsToInput, RankReducedExtractSliceOp,
+ RankReducedInsertSliceOp<tensor::InsertSliceOp>,
+ RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(context);
linalg::FillOp::getCanonicalizationPatterns(patterns, context);
tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
@@ -673,6 +680,14 @@ void mlir::linalg::populateFoldUnitExtentDimsPatterns(
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}
+void mlir::linalg::populateFoldUnitExtentDimsViaSlicesPatterns(
+ RewritePatternSet &patterns) {
+ auto *context = patterns.getContext();
+ patterns.add<ReplaceUnitExtents>(context,
+ RankReductionStrategy::ExtractInsertSlice);
+ patterns.add<FoldUnitDimLoops>(context);
+}
+
namespace {
/// Pass that removes unit-extent dims within generic ops.
struct LinalgFoldUnitExtentDimsPass
@@ -681,10 +696,13 @@ struct LinalgFoldUnitExtentDimsPass
Operation *op = getOperation();
MLIRContext *context = op->getContext();
RewritePatternSet patterns(context);
- if (foldOneTripLoopsOnly)
+ if (foldOneTripLoopsOnly) {
patterns.add<FoldUnitDimLoops, AddInitOperandsToInput>(context);
- else
- populateFoldUnitExtentDimsPatterns(patterns);
+ } else if (useRankReducingSlices) {
+ populateFoldUnitExtentDimsViaSlicesPatterns(patterns);
+ } else {
+ populateFoldUnitExtentDimsViaReshapesPatterns(patterns);
+ }
(void)applyPatternsAndFoldGreedily(op, std::move(patterns));
}
};
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index de9690fd2eb78..1d9011091c3a8 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -109,6 +109,21 @@ Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
return NoneType::get(type.getContext());
}
+SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
+ Location loc, Value value) {
+ auto memrefType = value.getType().cast<MemRefType>();
+ SmallVector<OpFoldResult> result;
+ for (int64_t i = 0; i < memrefType.getRank(); ++i) {
+ if (memrefType.isDynamicDim(i)) {
+ Value size = builder.create<memref::DimOp>(loc, value, i);
+ result.push_back(size);
+ } else {
+ result.push_back(builder.getIndexAttr(memrefType.getDimSize(i)));
+ }
+ }
+ return result;
+}
+
//===----------------------------------------------------------------------===//
// Utility functions for propagating static information
//===----------------------------------------------------------------------===//
@@ -2912,6 +2927,35 @@ static MemRefType getCanonicalSubViewResultType(
mixedStrides);
}
+Value mlir::memref::createCanonicalRankReducingSubViewOp(
+ OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
+ auto memrefType = memref.getType().cast<MemRefType>();
+ unsigned rank = memrefType.getRank();
+ SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
+ SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
+ SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
+ auto targetType = SubViewOp::inferRankReducedResultType(
+ targetShape, memrefType, offsets, sizes, strides)
+ .cast<MemRefType>();
+ return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
+ sizes, strides);
+}
+
+FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
+ Value value,
+ ArrayRef<int64_t> desiredShape) {
+ auto sourceMemrefType = value.getType().dyn_cast<MemRefType>();
+ assert(sourceMemrefType && "not a ranked memref type");
+ auto sourceShape = sourceMemrefType.getShape();
+ if (sourceShape.equals(desiredShape))
+ return value;
+ auto maybeRankReductionMask =
+ mlir::computeRankReductionMask(sourceShape, desiredShape);
+ if (!maybeRankReductionMask)
+ return failure();
+ return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
+}
+
/// Helper method to check if a `subview` operation is trivially a no-op. This
/// is the case if the all offsets are zero, all strides are 1, and the source
/// shape is same as the size of the subview. In such cases, the subview can
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index f9f06b0b55464..cfce3b3eea70e 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -split-input-file -pass-pipeline="builtin.module(func.func(linalg-fold-unit-extent-dims))" | FileCheck %s
+// RUN: mlir-opt %s -linalg-fold-unit-extent-dims -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -linalg-fold-unit-extent-dims="use-rank-reducing-slices" -cse -split-input-file | FileCheck %s --check-prefix=CHECK-SLICES
#accesses = [
affine_map<(i, j, k, l, m) -> (i, k, m)>,
@@ -26,11 +27,57 @@ func.func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>, %arg1 : f32, %shape: t
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @drop_one_trip_loops
// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1], [2]]
+// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]]
+// CHECK-SLICES-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-SLICES-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
+// CHECK-SLICES-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-SLICES-LABEL: func @drop_one_trip_loops
+// CHECK-SLICES: tensor.extract_slice %{{.*}}[0, 0, 0] [%{{.*}}, 1, %{{.*}}] [1, 1, 1] : tensor<?x1x?xf32> to tensor<?x?xf32>
+// CHECK-SLICES: tensor.extract_slice %{{.*}}[0, 0, 0, 0, 0] [%{{.*}}, 1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1, 1] : tensor<?x1x?x1x?xf32> to tensor<?x?x?xf32>
+// CHECK-SLICES: linalg.generic
+// CHECK-SLICES-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
+// CHECK-SLICES-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SLICES: tensor.insert_slice %{{.*}} into %{{.*}}[0, 0, 0, 0, 0] [%{{.*}}, 1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x1x?x1x?xf32>
+
+
+// -----
+
+#accesses = [
+ affine_map<(i, j, k, l, m) -> (i, k, m)>,
+ affine_map<(i, j, k, l, m) -> ()>,
+ affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
+]
+
+#trait = {
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+ indexing_maps = #accesses,
+ library_call = "some_external_func"
+}
+
+func.func @drop_one_trip_loops_all_ones(%arg0 : tensor<1x1x1xf32>, %arg1 : f32, %shape: tensor<?x1x?x1x?xf32>) -> tensor<?x1x?x1x?xf32> {
+ %0 = linalg.generic #trait
+ ins(%arg0, %arg1 : tensor<1x1x1xf32>, f32)
+ outs(%shape : tensor<?x1x?x1x?xf32>) {
+ ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32) :
+ linalg.yield %arg3 : f32
+ } -> tensor<?x1x?x1x?xf32>
+ return %0 : tensor<?x1x?x1x?xf32>
+}
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> ()>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (0, d0, 0)>
+// CHECK-LABEL: func @drop_one_trip_loops_all_ones
+// CHECK: tensor.collapse_shape %{{.*}} []
+// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]]
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel"]
+// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]]
+
// -----
#accesses = [
@@ -871,3 +918,8 @@ func.func @drop_all_loops(%arg0 : memref<1x1xf32, 3>) -> memref<1x1xf32, 3>
// CHECK: memref.collapse_shape
// CHECK-SAME: [] : memref<1x1xf32, 3> into memref<f32, 3>
// CHECK: linalg.generic{{.*}}memref<f32, 3>
+
+// CHECK-SLICES-LABEL: func @drop_all_loops
+// CHECK-SLICES: memref.subview %{{.*}}[0, 0] [1, 1] [1, 1] : memref<1x1xf32, 3> to memref<f32, strided<[]>, 3>
+// CHECK-SLICES: linalg.generic{{.*}}memref<f32, strided<[]>, 3>
+