[Mlir-commits] [mlir] [mlir][tensor] Simplify ExtractSliceOp::inferResultType (nfc) (PR #169313)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 24 02:42:27 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Andrzej Warzyński (banach-space)
<details>
<summary>Changes</summary>
The `offsets` and `strides` arguments are neither used nor required, so
remove them and simplify this hook.
---
Full diff: https://github.com/llvm/llvm-project/pull/169313.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+5-13)
- (modified) mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp (+1-2)
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+21-26)
- (modified) mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp (+1-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 2453cf5b5b5a4..ca2464f6272d3 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -486,17 +486,13 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
/// An extract_slice result type can be inferred, when it is not
/// rank-reduced, from the source type and the static representation of
- /// offsets, sizes and strides. Special sentinels encode the dynamic case.
+ /// sizes. Special sentinels encode the dynamic case.
static RankedTensorType inferResultType(
RankedTensorType sourceTensorType,
- ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes,
- ArrayRef<int64_t> staticStrides);
+ ArrayRef<int64_t> staticSizes);
static RankedTensorType inferResultType(
RankedTensorType sourceTensorType,
- ArrayRef<OpFoldResult> staticOffsets,
- ArrayRef<OpFoldResult> staticSizes,
- ArrayRef<OpFoldResult> staticStrides);
+ ArrayRef<OpFoldResult> staticSizes);
/// If the rank is reduced (i.e. the desiredResultRank is smaller than the
/// number of sizes), drop as many size 1 as needed to produce an inferred type
@@ -509,15 +505,11 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
static RankedTensorType inferCanonicalRankReducedResultType(
unsigned resultRank,
RankedTensorType sourceRankedTensorType,
- ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes,
- ArrayRef<int64_t> staticStrides);
+ ArrayRef<int64_t> staticSizes);
static RankedTensorType inferCanonicalRankReducedResultType(
unsigned resultRank,
RankedTensorType sourceRankedTensorType,
- ArrayRef<OpFoldResult> staticOffsets,
- ArrayRef<OpFoldResult> staticSizes,
- ArrayRef<OpFoldResult> staticStrides);
+ ArrayRef<OpFoldResult> staticSizes);
/// Return the expected rank of each of the`static_offsets`, `static_sizes`
/// and `static_strides` attributes.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 22690daa4f9e1..9e6c1e6036cba 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -747,8 +747,7 @@ struct RankReducedExtractSliceOp
SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
auto rankReducedType = cast<RankedTensorType>(
tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
- reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
- strides));
+ reassociation->size(), sliceOp.getSourceType(), sizes));
Location loc = sliceOp.getLoc();
Value newSlice = tensor::ExtractSliceOp::create(
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 110bfdce72ea4..125db6249b23d 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2293,9 +2293,9 @@ void ExtractSliceOp::getAsmResultNames(
/// An extract_slice result type can be inferred, when it is not
/// rank-reduced, from the source type and the static representation of
/// offsets, sizes and strides. Special sentinels encode the dynamic case.
-RankedTensorType ExtractSliceOp::inferResultType(
- RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
+RankedTensorType
+ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
+ ArrayRef<int64_t> staticSizes) {
// An extract_slice op may specify only a leading subset of offset/sizes/
// strides in which case we complete with offset=0, sizes from memref type
// and strides=1.
@@ -2307,11 +2307,12 @@ RankedTensorType ExtractSliceOp::inferResultType(
}
// TODO: This uses neither offsets nor strides!
-RankedTensorType ExtractSliceOp::inferResultType(
- RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
+RankedTensorType
+ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
+ ArrayRef<OpFoldResult> sizes) {
SmallVector<int64_t> staticSizes;
std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes);
+
assert(static_cast<int64_t>(staticSizes.size()) ==
sourceTensorType.getRank() &&
"unexpected staticSizes not equal to rank of source");
@@ -2329,11 +2330,10 @@ RankedTensorType ExtractSliceOp::inferResultType(
/// To disambiguate, this function always drops the first 1 sizes occurrences.
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
- ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
- ArrayRef<int64_t> strides) {
+ ArrayRef<int64_t> sizes) {
// Type inferred in the absence of rank-reducing behavior.
auto inferredType = llvm::cast<RankedTensorType>(
- inferResultType(sourceRankedTensorType, offsets, sizes, strides));
+ inferResultType(sourceRankedTensorType, sizes));
int rankDiff = inferredType.getRank() - desiredResultRank;
if (rankDiff > 0) {
auto shape = inferredType.getShape();
@@ -2352,16 +2352,12 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
- ArrayRef<OpFoldResult> strides) {
- SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
- SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
- dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+ ArrayRef<OpFoldResult> sizes) {
+ SmallVector<int64_t> staticSizes;
+ SmallVector<Value> dynamicSizes;
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
- dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
return ExtractSliceOp::inferCanonicalRankReducedResultType(
- desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
- staticStrides);
+ desiredResultRank, sourceRankedTensorType, staticSizes);
}
/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
@@ -2380,8 +2376,8 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
// Structuring implementation this way avoids duplication between builders.
if (!resultType) {
- resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
- sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
+ resultType = llvm::cast<RankedTensorType>(
+ ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes));
}
result.addAttributes(attrs);
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
@@ -2456,8 +2452,8 @@ LogicalResult ExtractSliceOp::verify() {
RankedTensorType sourceType = getSourceType();
// Verify result type against inferred type.
- RankedTensorType expectedType = ExtractSliceOp::inferResultType(
- sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());
+ RankedTensorType expectedType =
+ ExtractSliceOp::inferResultType(sourceType, getMixedSizes());
SliceVerificationResult result = isRankReducedType(expectedType, getType());
if (result != SliceVerificationResult::Success)
return produceSliceErrorMsg(result, *this, expectedType);
@@ -2697,8 +2693,7 @@ struct SliceReturnTypeCanonicalizer {
ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides) {
return ExtractSliceOp::inferCanonicalRankReducedResultType(
- op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
- mixedStrides);
+ op.getType().getRank(), op.getSourceType(), mixedSizes);
}
};
@@ -2839,8 +2834,8 @@ static SliceVerificationResult verifyInsertSliceOp(
ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
// insert_slice is the inverse of extract_slice, use the same type
// inference.
- RankedTensorType expected = ExtractSliceOp::inferResultType(
- dstType, staticOffsets, staticSizes, staticStrides);
+ RankedTensorType expected =
+ ExtractSliceOp::inferResultType(dstType, staticSizes);
if (expectedType)
*expectedType = expected;
return isRankReducedType(expected, srcType);
@@ -2968,7 +2963,7 @@ class InsertSliceOpConstantArgumentFolder final
// Create the new op in canonical form.
auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
- mixedOffsets, mixedSizes, mixedStrides);
+ mixedSizes);
Value toInsert = insertSliceOp.getSource();
if (sourceType != insertSliceOp.getSourceType()) {
OpBuilder::InsertionGuard g(rewriter);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index 7ec61c7df81cf..421f9ab7ceff7 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -37,8 +37,7 @@ struct FoldExpandOfRankReducingExtract
// supported. Moreover, only simple cases where the resulting ExtractSliceOp
// has no rank-reduction anymore are supported at the moment.
RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
- srcType, extractSliceOp.getStaticOffsets(),
- extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
+ srcType, extractSliceOp.getStaticSizes());
if (nonReducingExtractType != resultType)
return failure();
``````````
</details>
https://github.com/llvm/llvm-project/pull/169313
More information about the Mlir-commits
mailing list