[Mlir-commits] [mlir] c582688 - [MLIR][tensor] Simplify ExtractSliceOp::inferResultType (nfc) (#169313)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Nov 25 08:14:55 PST 2025
Author: Andrzej Warzyński
Date: 2025-11-25T16:14:51Z
New Revision: c582688b6912c615da1d08630c178dd3d0072aeb
URL: https://github.com/llvm/llvm-project/commit/c582688b6912c615da1d08630c178dd3d0072aeb
DIFF: https://github.com/llvm/llvm-project/commit/c582688b6912c615da1d08630c178dd3d0072aeb.diff
LOG: [MLIR][tensor] Simplify ExtractSliceOp::inferResultType (nfc) (#169313)
The `offsets` and `strides` arguments are neither used nor required -
remove them and simplify this hook.
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 3e93e58575e65..ac40d5e454281 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -490,17 +490,13 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
/// An extract_slice result type can be inferred, when it is not
/// rank-reduced, from the source type and the static representation of
- /// offsets, sizes and strides. Special sentinels encode the dynamic case.
+ /// sizes. Special sentinels encode the dynamic case.
static RankedTensorType inferResultType(
RankedTensorType sourceTensorType,
- ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes,
- ArrayRef<int64_t> staticStrides);
+ ArrayRef<int64_t> staticSizes);
static RankedTensorType inferResultType(
RankedTensorType sourceTensorType,
- ArrayRef<OpFoldResult> staticOffsets,
- ArrayRef<OpFoldResult> staticSizes,
- ArrayRef<OpFoldResult> staticStrides);
+ ArrayRef<OpFoldResult> staticSizes);
/// If the rank is reduced (i.e. the desiredResultRank is smaller than the
/// number of sizes), drop as many size 1 as needed to produce an inferred type
@@ -513,15 +509,11 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
static RankedTensorType inferCanonicalRankReducedResultType(
unsigned resultRank,
RankedTensorType sourceRankedTensorType,
- ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes,
- ArrayRef<int64_t> staticStrides);
+ ArrayRef<int64_t> staticSizes);
static RankedTensorType inferCanonicalRankReducedResultType(
unsigned resultRank,
RankedTensorType sourceRankedTensorType,
- ArrayRef<OpFoldResult> staticOffsets,
- ArrayRef<OpFoldResult> staticSizes,
- ArrayRef<OpFoldResult> staticStrides);
+ ArrayRef<OpFoldResult> staticSizes);
/// Return the expected rank of each of the`static_offsets`, `static_sizes`
/// and `static_strides` attributes.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 22690daa4f9e1..9e6c1e6036cba 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -747,8 +747,7 @@ struct RankReducedExtractSliceOp
SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
auto rankReducedType = cast<RankedTensorType>(
tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
- reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
- strides));
+ reassociation->size(), sliceOp.getSourceType(), sizes));
Location loc = sliceOp.getLoc();
Value newSlice = tensor::ExtractSliceOp::create(
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index b1893f0868ac5..5a58d7cbed30f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2291,9 +2291,9 @@ void ExtractSliceOp::getAsmResultNames(
/// An extract_slice result type can be inferred, when it is not
/// rank-reduced, from the source type and the static representation of
/// offsets, sizes and strides. Special sentinels encode the dynamic case.
-RankedTensorType ExtractSliceOp::inferResultType(
- RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
+RankedTensorType
+ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
+ ArrayRef<int64_t> staticSizes) {
// An extract_slice op may specify only a leading subset of offset/sizes/
// strides in which case we complete with offset=0, sizes from memref type
// and strides=1.
@@ -2305,11 +2305,12 @@ RankedTensorType ExtractSliceOp::inferResultType(
}
// TODO: This uses neither offsets nor strides!
-RankedTensorType ExtractSliceOp::inferResultType(
- RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
+RankedTensorType
+ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
+ ArrayRef<OpFoldResult> sizes) {
SmallVector<int64_t> staticSizes;
std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes);
+
assert(static_cast<int64_t>(staticSizes.size()) ==
sourceTensorType.getRank() &&
"unexpected staticSizes not equal to rank of source");
@@ -2327,11 +2328,10 @@ RankedTensorType ExtractSliceOp::inferResultType(
/// To disambiguate, this function always drops the first 1 sizes occurrences.
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
- ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
- ArrayRef<int64_t> strides) {
+ ArrayRef<int64_t> sizes) {
// Type inferred in the absence of rank-reducing behavior.
auto inferredType = llvm::cast<RankedTensorType>(
- inferResultType(sourceRankedTensorType, offsets, sizes, strides));
+ inferResultType(sourceRankedTensorType, sizes));
int rankDiff = inferredType.getRank() - desiredResultRank;
if (rankDiff > 0) {
auto shape = inferredType.getShape();
@@ -2350,16 +2350,12 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
- ArrayRef<OpFoldResult> strides) {
- SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
- SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
- dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+ ArrayRef<OpFoldResult> sizes) {
+ SmallVector<int64_t> staticSizes;
+ SmallVector<Value> dynamicSizes;
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
- dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
return ExtractSliceOp::inferCanonicalRankReducedResultType(
- desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
- staticStrides);
+ desiredResultRank, sourceRankedTensorType, staticSizes);
}
/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
@@ -2378,8 +2374,8 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
// Structuring implementation this way avoids duplication between builders.
if (!resultType) {
- resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
- sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
+ resultType = llvm::cast<RankedTensorType>(
+ ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes));
}
result.addAttributes(attrs);
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
@@ -2454,8 +2450,8 @@ LogicalResult ExtractSliceOp::verify() {
RankedTensorType sourceType = getSourceType();
// Verify result type against inferred type.
- RankedTensorType expectedType = ExtractSliceOp::inferResultType(
- sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());
+ RankedTensorType expectedType =
+ ExtractSliceOp::inferResultType(sourceType, getMixedSizes());
SliceVerificationResult result = isRankReducedType(expectedType, getType());
if (result != SliceVerificationResult::Success)
return produceSliceErrorMsg(result, *this, expectedType);
@@ -2695,8 +2691,7 @@ struct SliceReturnTypeCanonicalizer {
ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides) {
return ExtractSliceOp::inferCanonicalRankReducedResultType(
- op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
- mixedStrides);
+ op.getType().getRank(), op.getSourceType(), mixedSizes);
}
};
@@ -2837,8 +2832,8 @@ static SliceVerificationResult verifyInsertSliceOp(
ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
// insert_slice is the inverse of extract_slice, use the same type
// inference.
- RankedTensorType expected = ExtractSliceOp::inferResultType(
- dstType, staticOffsets, staticSizes, staticStrides);
+ RankedTensorType expected =
+ ExtractSliceOp::inferResultType(dstType, staticSizes);
if (expectedType)
*expectedType = expected;
return isRankReducedType(expected, srcType);
@@ -2966,7 +2961,7 @@ class InsertSliceOpConstantArgumentFolder final
// Create the new op in canonical form.
auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
- mixedOffsets, mixedSizes, mixedStrides);
+ mixedSizes);
Value toInsert = insertSliceOp.getSource();
if (sourceType != insertSliceOp.getSourceType()) {
OpBuilder::InsertionGuard g(rewriter);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index aa52c0c138d0b..a53af98474245 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -37,8 +37,7 @@ struct FoldExpandOfRankReducingExtract
// supported. Moreover, only simple cases where the resulting ExtractSliceOp
// has no rank-reduction anymore are supported at the moment.
RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
- srcType, extractSliceOp.getStaticOffsets(),
- extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
+ srcType, extractSliceOp.getStaticSizes());
if (nonReducingExtractType != resultType)
return failure();
More information about the Mlir-commits
mailing list