[Mlir-commits] [mlir] [MLIR][NFC] Return MemRefType in memref.subview return type inference functions (PR #120024)
Tomás Longeri
llvmlistbot at llvm.org
Sun Dec 15 15:28:01 PST 2024
https://github.com/tlongeri created https://github.com/llvm/llvm-project/pull/120024
None
From 556c406e32ecc8fbf43305f89fe36aed383ec801 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= <tlongeri at google.com>
Date: Sun, 15 Dec 2024 22:55:05 +0000
Subject: [PATCH] [MLIR][NFC] Return MemRefType in memref.subview return type
inference functions
---
.../mlir/Dialect/MemRef/IR/MemRefOps.td | 36 +++++------
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 61 +++++++++----------
.../Transforms/IndependenceTransforms.cpp | 4 +-
.../Dialect/MemRef/Transforms/MultiBuffer.cpp | 12 ++--
.../BufferizableOpInterfaceImpl.cpp | 16 ++---
.../Transforms/VectorTransferOpTransforms.cpp | 4 +-
.../Vector/Transforms/VectorTransforms.cpp | 14 ++---
7 files changed, 70 insertions(+), 77 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index a0d8d34f38237a..4e31bb153c5e7e 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -2079,14 +2079,14 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
- static Type inferResultType(MemRefType sourceMemRefType,
- ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes,
- ArrayRef<int64_t> staticStrides);
- static Type inferResultType(MemRefType sourceMemRefType,
- ArrayRef<OpFoldResult> staticOffsets,
- ArrayRef<OpFoldResult> staticSizes,
- ArrayRef<OpFoldResult> staticStrides);
+ static MemRefType inferResultType(MemRefType sourceMemRefType,
+ ArrayRef<int64_t> staticOffsets,
+ ArrayRef<int64_t> staticSizes,
+ ArrayRef<int64_t> staticStrides);
+ static MemRefType inferResultType(MemRefType sourceMemRefType,
+ ArrayRef<OpFoldResult> staticOffsets,
+ ArrayRef<OpFoldResult> staticSizes,
+ ArrayRef<OpFoldResult> staticStrides);
/// A rank-reducing result type can be inferred from the desired result
/// shape. Only the layout map is inferred.
@@ -2095,16 +2095,16 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
/// and the desired sizes. In case there are more "ones" among the sizes
/// than the difference in source/result rank, it is not clear which dims of
/// size one should be dropped.
- static Type inferRankReducedResultType(ArrayRef<int64_t> resultShape,
- MemRefType sourceMemRefType,
- ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes,
- ArrayRef<int64_t> staticStrides);
- static Type inferRankReducedResultType(ArrayRef<int64_t> resultShape,
- MemRefType sourceMemRefType,
- ArrayRef<OpFoldResult> staticOffsets,
- ArrayRef<OpFoldResult> staticSizes,
- ArrayRef<OpFoldResult> staticStrides);
+ static MemRefType inferRankReducedResultType(
+ ArrayRef<int64_t> resultShape, MemRefType sourceMemRefType,
+ ArrayRef<int64_t> staticOffsets,
+ ArrayRef<int64_t> staticSizes,
+ ArrayRef<int64_t> staticStrides);
+ static MemRefType inferRankReducedResultType(
+ ArrayRef<int64_t> resultShape, MemRefType sourceMemRefType,
+ ArrayRef<OpFoldResult> staticOffsets,
+ ArrayRef<OpFoldResult> staticSizes,
+ ArrayRef<OpFoldResult> staticStrides);
/// Return the expected rank of each of the`static_offsets`, `static_sizes`
/// and `static_strides` attributes.
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 2219505c9b802f..12768f06fb1b0e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2697,10 +2697,10 @@ void SubViewOp::getAsmResultNames(
/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
-Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
- ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes,
- ArrayRef<int64_t> staticStrides) {
+MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
+ ArrayRef<int64_t> staticOffsets,
+ ArrayRef<int64_t> staticSizes,
+ ArrayRef<int64_t> staticStrides) {
unsigned rank = sourceMemRefType.getRank();
(void)rank;
assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
@@ -2739,10 +2739,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
sourceMemRefType.getMemorySpace());
}
-Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
- ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes,
- ArrayRef<OpFoldResult> strides) {
+MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ ArrayRef<OpFoldResult> strides) {
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -2758,13 +2758,12 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
staticSizes, staticStrides);
}
-Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
- MemRefType sourceRankedTensorType,
- ArrayRef<int64_t> offsets,
- ArrayRef<int64_t> sizes,
- ArrayRef<int64_t> strides) {
- auto inferredType = llvm::cast<MemRefType>(
- inferResultType(sourceRankedTensorType, offsets, sizes, strides));
+MemRefType SubViewOp::inferRankReducedResultType(
+ ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
+ ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
+ ArrayRef<int64_t> strides) {
+ MemRefType inferredType =
+ inferResultType(sourceRankedTensorType, offsets, sizes, strides);
assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
"expected ");
if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
@@ -2790,11 +2789,10 @@ Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
inferredType.getMemorySpace());
}
-Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
- MemRefType sourceRankedTensorType,
- ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes,
- ArrayRef<OpFoldResult> strides) {
+MemRefType SubViewOp::inferRankReducedResultType(
+ ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ ArrayRef<OpFoldResult> strides) {
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -2821,8 +2819,8 @@ void SubViewOp::build(OpBuilder &b, OperationState &result,
auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
// Structuring implementation this way avoids duplication between builders.
if (!resultType) {
- resultType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
- sourceMemRefType, staticOffsets, staticSizes, staticStrides));
+ resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
+ staticSizes, staticStrides);
}
result.addAttributes(attrs);
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
@@ -2987,8 +2985,8 @@ LogicalResult SubViewOp::verify() {
// Compute the expected result type, assuming that there are no rank
// reductions.
- auto expectedType = cast<MemRefType>(SubViewOp::inferResultType(
- baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()));
+ MemRefType expectedType = SubViewOp::inferResultType(
+ baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides());
// Verify all properties of a shaped type: rank, element type and dimension
// sizes. This takes into account potential rank reductions.
@@ -3070,8 +3068,8 @@ static MemRefType getCanonicalSubViewResultType(
MemRefType currentResultType, MemRefType currentSourceType,
MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
- auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
- sourceType, mixedOffsets, mixedSizes, mixedStrides));
+ MemRefType nonRankReducedType = SubViewOp::inferResultType(
+ sourceType, mixedOffsets, mixedSizes, mixedStrides);
FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
currentSourceType, currentResultType, mixedSizes);
if (failed(unusedDims))
@@ -3105,9 +3103,8 @@ Value mlir::memref::createCanonicalRankReducingSubViewOp(
SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
- auto targetType =
- llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType(
- targetShape, memrefType, offsets, sizes, strides));
+ MemRefType targetType = SubViewOp::inferRankReducedResultType(
+ targetShape, memrefType, offsets, sizes, strides);
return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
sizes, strides);
}
@@ -3251,11 +3248,11 @@ struct SubViewReturnTypeCanonicalizer {
ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides) {
// Infer a memref type without taking into account any rank reductions.
- auto resTy = SubViewOp::inferResultType(op.getSourceType(), mixedOffsets,
- mixedSizes, mixedStrides);
+ MemRefType resTy = SubViewOp::inferResultType(
+ op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
if (!resTy)
return {};
- MemRefType nonReducedType = cast<MemRefType>(resTy);
+ MemRefType nonReducedType = resTy;
// Directly return the non-rank reduced type if there are no dropped dims.
llvm::SmallBitVector droppedDims = op.getDroppedDims();
diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
index 1f06318cbd60e0..8ffea5a7839980 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -70,9 +70,9 @@ propagateSubViewOp(RewriterBase &rewriter,
UnrealizedConversionCastOp conversionOp, SubViewOp op) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(op);
- auto newResultType = cast<MemRefType>(SubViewOp::inferRankReducedResultType(
+ MemRefType newResultType = SubViewOp::inferRankReducedResultType(
op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
- op.getMixedSizes(), op.getMixedStrides()));
+ op.getMixedSizes(), op.getMixedStrides());
Value newSubview = rewriter.create<SubViewOp>(
op.getLoc(), newResultType, conversionOp.getOperand(0),
op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides());
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index bc0dd034f63851..c475d92e0658e5 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -60,14 +60,13 @@ static void replaceUsesAndPropagateType(RewriterBase &rewriter,
// `subview(old_op)` is replaced by a new `subview(val)`.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(subviewUse);
- Type newType = memref::SubViewOp::inferRankReducedResultType(
+ MemRefType newType = memref::SubViewOp::inferRankReducedResultType(
subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
subviewUse.getStaticStrides());
Value newSubview = rewriter.create<memref::SubViewOp>(
- subviewUse->getLoc(), cast<MemRefType>(newType), val,
- subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
- subviewUse.getMixedStrides());
+ subviewUse->getLoc(), newType, val, subviewUse.getMixedOffsets(),
+ subviewUse.getMixedSizes(), subviewUse.getMixedStrides());
// Ouch recursion ... is this really necessary?
replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
@@ -211,9 +210,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]);
// Strides is [1, 1 ... 1 ].
- auto dstMemref =
- cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
- originalShape, mbMemRefType, offsets, sizes, strides));
+ MemRefType dstMemref = memref::SubViewOp::inferRankReducedResultType(
+ originalShape, mbMemRefType, offsets, sizes, strides);
Value subview = rewriter.create<memref::SubViewOp>(loc, dstMemref, mbAlloc,
offsets, sizes, strides);
LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 9797b73f534a96..35862c74c57552 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -407,10 +407,10 @@ struct ExtractSliceOpInterface
SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
- return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
+ return memref::SubViewOp::inferRankReducedResultType(
extractSliceOp.getType().getShape(),
llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
- mixedStrides));
+ mixedStrides);
}
};
@@ -692,10 +692,10 @@ struct InsertSliceOpInterface
// Take a subview of the destination buffer.
auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
- auto subviewMemRefType =
- cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
+ MemRefType subviewMemRefType =
+ memref::SubViewOp::inferRankReducedResultType(
insertSliceOp.getSourceType().getShape(), dstMemrefType,
- mixedOffsets, mixedSizes, mixedStrides));
+ mixedOffsets, mixedSizes, mixedStrides);
Value subView = rewriter.create<memref::SubViewOp>(
loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
mixedStrides);
@@ -960,12 +960,12 @@ struct ParallelInsertSliceOpInterface
// Take a subview of the destination buffer.
auto destBufferType = cast<MemRefType>(destBuffer->getType());
- auto subviewMemRefType =
- cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
+ MemRefType subviewMemRefType =
+ memref::SubViewOp::inferRankReducedResultType(
parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
parallelInsertSliceOp.getMixedOffsets(),
parallelInsertSliceOp.getMixedSizes(),
- parallelInsertSliceOp.getMixedStrides()));
+ parallelInsertSliceOp.getMixedStrides());
Value subview = rewriter.create<memref::SubViewOp>(
parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
parallelInsertSliceOp.getMixedOffsets(),
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index bd5f06a3b46d42..b124ea32af1b32 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -265,9 +265,9 @@ static MemRefType dropUnitDims(MemRefType inputType,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) {
auto targetShape = getReducedShape(sizes);
- Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
+ MemRefType rankReducedType = memref::SubViewOp::inferRankReducedResultType(
targetShape, inputType, offsets, sizes, strides);
- return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
+ return canonicalizeStridedLayout(rankReducedType);
}
/// Creates a rank-reducing memref.subview op that drops unit dims from its
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 20cd9cba6909a6..3f3e9ae9df2865 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1319,10 +1319,9 @@ class DropInnerMostUnitDimsTransferRead
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(srcType.getRank(),
rewriter.getIndexAttr(1));
- auto resultMemrefType =
- cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
- srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
- strides));
+ MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
+ srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
+ strides);
ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
Value rankedReducedView = rewriter.create<memref::SubViewOp>(
@@ -1410,10 +1409,9 @@ class DropInnerMostUnitDimsTransferWrite
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(srcType.getRank(),
rewriter.getIndexAttr(1));
- auto resultMemrefType =
- cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
- srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
- strides));
+ MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
+ srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
+ strides);
ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
More information about the Mlir-commits
mailing list