[Mlir-commits] [mlir] 741f8f2 - [mlir][Tensor][NFC] Better document rank-reducing behavior of ExtractSliceOp and cleanup
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Jun 29 07:39:25 PDT 2022
Author: Nicolas Vasilache
Date: 2022-06-29T07:37:58-07:00
New Revision: 741f8f2bede58573560372bc219b2dec9a1d6643
URL: https://github.com/llvm/llvm-project/commit/741f8f2bede58573560372bc219b2dec9a1d6643
DIFF: https://github.com/llvm/llvm-project/commit/741f8f2bede58573560372bc219b2dec9a1d6643.diff
LOG: [mlir][Tensor][NFC] Better document rank-reducing behavior of ExtractSliceOp and cleanup
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index bd9ab3545ccb0..a6de3e9597d72 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -210,6 +210,25 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
flexibility allows progressively dropping unit dimensions while lowering
between different flavors of ops that operate on tensors.
+ Verification vs Inference in the rank-reduced case:
+ ===================================================
+ Note that there may be multiple ways to infer a resulting rank-reduced type.
+ e.g. 1x6x1 could potentially rank-reduce to either 1x6 or 6x1 2-D shapes.
+
+ To disambiguate, the inference helper `inferCanonicalRankReducedResultType`
+ only drops the first unit dimensions, in order:
+ e.g. 1x6x1 rank-reduced to 2-D will infer the 6x1 2-D shape, but not 1x6.
+
+ Verification, however, has access to the result type and does not need to
+ infer it.
+ The verifier calls `isRankReducedType(getSource(), getResult())` to
+ determine whether the result type is rank-reduced from the source type.
+ This computes a so-called rank-reduction mask, consisting of the positions
+ of the dropped unit dims, that relates the rank-reduced type to the source
+ type:
+ e.g. 1x6 is a rank-reduced version of 1x6x1 by mask {2}
+ 6x1 is a rank-reduced version of 1x6x1 by mask {0}
+ 1x2x1x4 is a rank-reduced version of 1x1x2x1x1x4x1 by mask {1, 4, 6}
+ (remaining common 1 dimensions are matched eagerly)
+
Example:
```
@@ -274,26 +293,43 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
return getResult().getType().cast<RankedTensorType>();
}
- /// An extract_slice result type can be fully inferred from the source type
- /// and the static representation of offsets, sizes and strides. Special
- /// sentinels encode the dynamic case.
+ /// Compute the rank-reduction mask that can be applied to map the source
+ /// tensor type to the result tensor type by dropping unit dims.
+ llvm::Optional<llvm::SmallDenseSet<unsigned>>
+ computeRankReductionMask() {
+ return ::mlir::computeRankReductionMask(getSourceType().getShape(),
+ getType().getShape());
+ };
+
+ /// An extract_slice result type can be inferred, when it is not
+ /// rank-reduced, from the source type and the static representation of
+ /// offsets, sizes and strides. Special sentinels encode the dynamic case.
static RankedTensorType inferResultType(
- RankedTensorType sourceRankedTensorType,
+ ShapedType sourceShapedTensorType,
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides);
static RankedTensorType inferResultType(
- RankedTensorType sourceRankedTensorType,
+ ShapedType sourceShapedTensorType,
ArrayRef<OpFoldResult> staticOffsets,
ArrayRef<OpFoldResult> staticSizes,
ArrayRef<OpFoldResult> staticStrides);
- static RankedTensorType inferRankReducedResultType(
+
+ /// If the rank is reduced (i.e. the desiredResultRank is smaller than the
+ /// number of sizes), drop as many size-1 dimensions as needed to produce an
+ /// inferred type with the desired rank.
+ ///
+ /// Note that there may be multiple ways to compute this rank-reduced type:
+ /// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
+ ///
+ /// To disambiguate, this function always drops the first size-1 dimensions, in order.
+ static RankedTensorType inferCanonicalRankReducedResultType(
unsigned resultRank,
RankedTensorType sourceRankedTensorType,
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides);
- static RankedTensorType inferRankReducedResultType(
+ static RankedTensorType inferCanonicalRankReducedResultType(
unsigned resultRank,
RankedTensorType sourceRankedTensorType,
ArrayRef<OpFoldResult> staticOffsets,
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp
index 10760685301fd..6c6bcabb499d9 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp
@@ -228,7 +228,7 @@ mlir::bufferization::insertSliceAnchoredAllocTensorEliminationStep(
return b.create<tensor::DimOp>(loc, target, dim).getResult();
return b.getIndexAttr(shapedType.getDimSize(dim));
});
- auto t = tensor::ExtractSliceOp::inferRankReducedResultType(
+ auto t = tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
insertOp.getSourceType().getRank(),
insertOp.getDest().getType().cast<RankedTensorType>(), mixedOffsets,
mixedSizes, mixedStrides);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 970b628a15c4b..e1e7ed76d23cc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -499,10 +499,11 @@ struct UseRankReducedExtractSliceOp
if (!reassociation ||
reassociation->size() == static_cast<size_t>(resultType.getRank()))
return failure();
- auto rankReducedType = tensor::ExtractSliceOp::inferRankReducedResultType(
- reassociation->size(), sliceOp.getSourceType(),
- offsets, sizes, strides)
- .cast<RankedTensorType>();
+ auto rankReducedType =
+ tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
+ reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
+ strides)
+ .cast<RankedTensorType>();
Location loc = sliceOp.getLoc();
Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 897af9fcee6f3..305e8f7e42394 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -957,25 +957,24 @@ OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
// ExtractSliceOp
//===----------------------------------------------------------------------===//
-/// An extract_slice op result type can be fully inferred from the source type
-/// and the static representation of offsets, sizes and strides. Special
-/// sentinels encode the dynamic case.
+/// An extract_slice result type can be inferred, when it is not
+/// rank-reduced, from the source type and the static representation of
+/// offsets, sizes and strides. Special sentinels encode the dynamic case.
RankedTensorType ExtractSliceOp::inferResultType(
- RankedTensorType sourceRankedTensorType, ArrayRef<int64_t> staticOffsets,
+ ShapedType sourceShapedTensorType, ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
// An extract_slice op may specify only a leading subset of offset/sizes/
// strides, in which case we complete with offset=0, sizes from the source
// type, and strides=1.
- unsigned rank = sourceRankedTensorType.getRank();
- (void)rank;
- assert(staticSizes.size() == rank &&
+ assert(static_cast<int64_t>(staticSizes.size()) ==
+ sourceShapedTensorType.getRank() &&
"unexpected staticSizes not equal to rank of source");
return RankedTensorType::get(staticSizes,
- sourceRankedTensorType.getElementType());
+ sourceShapedTensorType.getElementType());
}
RankedTensorType ExtractSliceOp::inferResultType(
- RankedTensorType sourceRankedTensorType, ArrayRef<OpFoldResult> offsets,
+ ShapedType sourceShapedTensorType, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
@@ -985,26 +984,33 @@ RankedTensorType ExtractSliceOp::inferResultType(
ShapedType::kDynamicSize);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
ShapedType::kDynamicStrideOrOffset);
- return ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
+ return ExtractSliceOp::inferResultType(sourceShapedTensorType, staticOffsets,
staticSizes, staticStrides);
}
-/// An extract_slice op result type can be fully inferred from the source type
-/// and the static representation of offsets, sizes and strides. Special
-/// sentinels encode the dynamic case.
-RankedTensorType ExtractSliceOp::inferRankReducedResultType(
- unsigned resultRank, RankedTensorType sourceRankedTensorType,
+/// If the rank is reduced (i.e. the desiredResultRank is smaller than the
+/// number of sizes), drop as many size-1 dimensions as needed to produce an
+/// inferred type with the desired rank.
+///
+/// Note that there may be multiple ways to compute this rank-reduced type:
+/// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
+///
+/// To disambiguate, this function always drops the first size-1 dimensions, in order.
+RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
+ unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides) {
+ // Type inferred in the absence of rank-reducing behavior.
auto inferredType =
inferResultType(sourceRankedTensorType, offsets, sizes, strides)
.cast<RankedTensorType>();
- int rankDiff = inferredType.getRank() - resultRank;
+ int rankDiff = inferredType.getRank() - desiredResultRank;
if (rankDiff > 0) {
auto shape = inferredType.getShape();
llvm::SmallBitVector dimsToProject =
getPositionsOfShapeOne(rankDiff, shape);
SmallVector<int64_t> projectedShape;
+ // Best effort rank-reducing: drop 1s in order.
for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
if (!dimsToProject.test(pos))
projectedShape.push_back(shape[pos]);
@@ -1014,8 +1020,8 @@ RankedTensorType ExtractSliceOp::inferRankReducedResultType(
return inferredType;
}
-RankedTensorType ExtractSliceOp::inferRankReducedResultType(
- unsigned resultRank, RankedTensorType sourceRankedTensorType,
+RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
+ unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) {
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
@@ -1026,8 +1032,8 @@ RankedTensorType ExtractSliceOp::inferRankReducedResultType(
ShapedType::kDynamicSize);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
ShapedType::kDynamicStrideOrOffset);
- return ExtractSliceOp::inferRankReducedResultType(
- resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
+ return ExtractSliceOp::inferCanonicalRankReducedResultType(
+ desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
staticStrides);
}
@@ -1123,26 +1129,6 @@ LogicalResult ExtractSliceOp::verify() {
return produceSliceErrorMsg(result, *this, expectedType);
}
-/// Infer the canonical type of the result of an extract_slice op. Returns a
-/// type with rank `resultRank` that is either the rank of the rank-reduced
-/// type, or the non-rank-reduced type.
-static RankedTensorType
-getCanonicalSliceResultType(unsigned resultRank, RankedTensorType sourceType,
- ArrayRef<OpFoldResult> mixedOffsets,
- ArrayRef<OpFoldResult> mixedSizes,
- ArrayRef<OpFoldResult> mixedStrides) {
- auto resultType =
- ExtractSliceOp::inferRankReducedResultType(
- resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
- .cast<RankedTensorType>();
- if (resultType.getRank() != resultRank) {
- resultType = ExtractSliceOp::inferResultType(sourceType, mixedOffsets,
- mixedSizes, mixedStrides)
- .cast<RankedTensorType>();
- }
- return resultType;
-}
-
llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
ArrayRef<int64_t> resultShape = getType().getShape();
SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
@@ -1205,7 +1191,7 @@ class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
PatternRewriter &rewriter) const override {
- // Any constant operand, just return to let SubViewOpConstantFolder kick in.
+ // Any constant operand, just return to let the constant folder kick in.
if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
return matchPattern(operand, matchConstantIndex());
}))
@@ -1219,10 +1205,11 @@ class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
return failure();
/// Deduce the type of the result to use for the canonicalized operation.
- RankedTensorType resultType = getCanonicalSliceResultType(
- sliceOp.getType().getRank(), sliceOp.getSourceType(),
- sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
- sliceOp.getMixedStrides());
+ RankedTensorType resultType =
+ ExtractSliceOp::inferCanonicalRankReducedResultType(
+ sliceOp.getType().getRank(), sliceOp.getSourceType(),
+ sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
+ sliceOp.getMixedStrides());
Value newSlice = rewriter.create<ExtractSliceOp>(
sliceOp.getLoc(), resultType, castOp.getSource(), sliceOp.getOffsets(),
sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
@@ -1366,9 +1353,9 @@ struct SliceReturnTypeCanonicalizer {
ArrayRef<OpFoldResult> mixedOffsets,
ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides) {
- return getCanonicalSliceResultType(op.getType().getRank(),
- op.getSourceType(), mixedOffsets,
- mixedSizes, mixedStrides);
+ return ExtractSliceOp::inferCanonicalRankReducedResultType(
+ op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
+ mixedStrides);
}
};
@@ -1506,9 +1493,8 @@ verifyInsertSliceOp(ShapedType srcType, ShapedType dstType,
ArrayAttr staticStrides,
ShapedType *expectedType = nullptr) {
// insert_slice is the inverse of extract_slice, use the same type inference.
- auto expected = ExtractSliceOp::inferRankReducedResultType(
- srcType.getRank(), dstType.cast<RankedTensorType>(),
- extractFromI64ArrayAttr(staticOffsets),
+ auto expected = ExtractSliceOp::inferResultType(
+ dstType, extractFromI64ArrayAttr(staticOffsets),
extractFromI64ArrayAttr(staticSizes),
extractFromI64ArrayAttr(staticStrides))
.cast<ShapedType>();
@@ -1600,7 +1586,7 @@ class InsertSliceOpConstantArgumentFolder final
canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
// Create the new op in canonical form.
- auto sourceType = ExtractSliceOp::inferRankReducedResultType(
+ auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
insertSliceOp.getSourceType().getRank(), insertSliceOp.getType(),
mixedOffsets, mixedSizes, mixedStrides);
Value toInsert = insertSliceOp.getSource();