[Mlir-commits] [mlir] 9cbd136 - [mlir][NFC] Add a new getStridesAndOffset function
Quentin Colombet
llvmlistbot at llvm.org
Wed Dec 7 06:06:09 PST 2022
Author: Quentin Colombet
Date: 2022-12-07T13:58:28Z
New Revision: 9cbd136db4850e44e7c4e076d7c820829023c619
URL: https://github.com/llvm/llvm-project/commit/9cbd136db4850e44e7c4e076d7c820829023c619
DIFF: https://github.com/llvm/llvm-project/commit/9cbd136db4850e44e7c4e076d7c820829023c619.diff
LOG: [mlir][NFC] Add a new getStridesAndOffset function
The new function is a wrapper around the regular `getStridesAndOffset`
that offers a more compact way (as in writing less code) of getting the
relevant information.
This method is intended to be used only when it is known that the
LogicalResult of the regular `getStridesAndOffset` is guaranteed to be
"succeeded". This wrapper asserts that it is.
Differential Revision: https://reviews.llvm.org/D139529
Added:
Modified:
mlir/include/mlir/IR/BuiltinTypes.h
mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
mlir/lib/Conversion/LLVMCommon/Pattern.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
mlir/lib/IR/BuiltinTypes.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 5aeea57c96716..8bdc672e77470 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -436,6 +436,10 @@ LogicalResult getStridesAndOffset(MemRefType t,
SmallVectorImpl<int64_t> &strides,
int64_t &offset);
+/// Convenience wrapper around
+/// getStridesAndOffset(MemRefType, SmallVectorImpl<int64_t> &, int64_t &) that
+/// returns the strides and the offset of `t` as a pair. It asserts that the
+/// underlying call succeeded, so it must only be used when that call is known
+/// to succeed for `t`. With assertions disabled (NDEBUG) a failure goes
+/// undetected and the returned values must not be relied upon.
+std::pair<SmallVector<int64_t>, int64_t> getStridesAndOffset(MemRefType t);
+
/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
index 9414b6fd3c413..d10c17e5fb5e6 100644
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -52,11 +52,7 @@ MemRefDescriptor MemRefDescriptor::fromStaticShape(
assert(type.hasStaticShape() && "unexpected dynamic shape");
// Extract all strides and offsets and verify they are static.
- int64_t offset;
- SmallVector<int64_t, 4> strides;
- auto result = getStridesAndOffset(type, strides, offset);
- (void)result;
- assert(succeeded(result) && "unexpected failure in stride computation");
+ auto [strides, offset] = getStridesAndOffset(type);
assert(!ShapedType::isDynamic(offset) &&
"expected static offset");
assert(!llvm::any_of(strides, ShapedType::isDynamic) &&
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 14799f865544f..6c28022e1a07f 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -69,11 +69,7 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
ConversionPatternRewriter &rewriter) const {
- int64_t offset;
- SmallVector<int64_t, 4> strides;
- auto successStrides = getStridesAndOffset(type, strides, offset);
- assert(succeeded(successStrides) && "unexpected non-strided memref");
- (void)successStrides;
+ auto [strides, offset] = getStridesAndOffset(type);
MemRefDescriptor memRefDescriptor(memRefDesc);
Value base = memRefDescriptor.alignedPtr(rewriter, loc);
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 223aec3d8cfe2..028cc53b18f94 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2538,11 +2538,7 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
assert(staticStrides.size() == rank && "staticStrides length mismatch");
// Extract source offset and strides.
- int64_t sourceOffset;
- SmallVector<int64_t, 4> sourceStrides;
- auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
- assert(succeeded(res) && "SubViewOp expected strided memref type");
- (void)res;
+ auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceMemRefType);
// Compute target offset whose value is:
// `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
@@ -3098,12 +3094,8 @@ static MemRefType inferTransposeResultType(MemRefType memRefType,
AffineMap permutationMap) {
auto rank = memRefType.getRank();
auto originalSizes = memRefType.getShape();
- int64_t offset;
- SmallVector<int64_t, 4> originalStrides;
- auto res = getStridesAndOffset(memRefType, originalStrides, offset);
- assert(succeeded(res) &&
- originalStrides.size() == static_cast<unsigned>(rank));
- (void)res;
+ auto [originalStrides, offset] = getStridesAndOffset(memRefType);
+ assert(originalStrides.size() == static_cast<unsigned>(rank));
// Compute permuted sizes and strides.
SmallVector<int64_t> sizes(rank, 0);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 549108d3b8f4e..cac5490a409cb 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -60,14 +60,7 @@ struct SubviewFolder : public OpRewritePattern<memref::SubViewOp> {
auto newExtractStridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
- SmallVector<int64_t> sourceStrides;
- int64_t sourceOffset;
-
- bool hasKnownStridesAndOffset =
- succeeded(getStridesAndOffset(sourceType, sourceStrides, sourceOffset));
- (void)hasKnownStridesAndOffset;
- assert(hasKnownStridesAndOffset &&
- "getStridesAndOffset must work on valid subviews");
+ auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceType);
// Compute the new strides and offset from the base strides and offset:
// newStride#i = baseStride#i * subStride#i
@@ -265,13 +258,7 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
// Collect the statically known information about the original stride.
Value source = expandShape.getSrc();
auto sourceType = source.getType().cast<MemRefType>();
- SmallVector<int64_t> strides;
- int64_t offset;
- bool hasKnownStridesAndOffset =
- succeeded(getStridesAndOffset(sourceType, strides, offset));
- (void)hasKnownStridesAndOffset;
- assert(hasKnownStridesAndOffset &&
- "getStridesAndOffset must work on valid expand_shape");
+ auto [strides, offset] = getStridesAndOffset(sourceType);
OpFoldResult origStride =
ShapedType::isDynamic(strides[groupId])
@@ -414,13 +401,7 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
Value source = collapseShape.getSrc();
auto sourceType = source.getType().cast<MemRefType>();
- SmallVector<int64_t> strides;
- int64_t offset;
- bool hasKnownStridesAndOffset =
- succeeded(getStridesAndOffset(sourceType, strides, offset));
- (void)hasKnownStridesAndOffset;
- assert(hasKnownStridesAndOffset &&
- "getStridesAndOffset must work on valid collapse_shape");
+ auto [strides, offset] = getStridesAndOffset(sourceType);
SmallVector<OpFoldResult> collapsedStride;
int64_t innerMostDimForGroup = reassocGroup.back();
@@ -473,13 +454,7 @@ struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> {
rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
// Collect statically known information.
- SmallVector<int64_t> strides;
- int64_t offset;
- bool hasKnownStridesAndOffset =
- succeeded(getStridesAndOffset(sourceType, strides, offset));
- (void)hasKnownStridesAndOffset;
- assert(hasKnownStridesAndOffset &&
- "getStridesAndOffset must work on valid reassociative_reshape_like");
+ auto [strides, offset] = getStridesAndOffset(sourceType);
MemRefType reshapeType = reshape.getResultType();
unsigned reshapeRank = reshapeType.getRank();
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 378b87d1284ca..4eef7dd496815 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -804,6 +804,16 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
return success();
}
+/// Check-free variant of getStridesAndOffset: returns {strides, offset} for
+/// `t` and asserts that the checked overload succeeded. With assertions
+/// disabled a failure is not detected, so callers must know beforehand that
+/// the checked overload succeeds for `t`.
+std::pair<SmallVector<int64_t>, int64_t>
+mlir::getStridesAndOffset(MemRefType t) {
+  SmallVector<int64_t> strides;
+  // Initialized so the returned offset is never an indeterminate value, even
+  // if the checked overload fails without writing it (NDEBUG builds).
+  int64_t offset = 0;
+  LogicalResult status = getStridesAndOffset(t, strides, offset);
+  (void)status; // Only read by the assert; silences NDEBUG unused warnings.
+  assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset");
+  // Move the vector into the pair instead of copying it.
+  return {std::move(strides), offset};
+}
+
//===----------------------------------------------------------------------===//
/// TupleType
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list