[Mlir-commits] [mlir] 6aaa8f2 - [mlir][IR][NFC] Move free-standing functions to `MemRefType` (#123465)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 20 23:48:13 PST 2025
Author: Matthias Springer
Date: 2025-01-21T08:48:09+01:00
New Revision: 6aaa8f25b66dc1fef4e465f274ee40b82d632988
URL: https://github.com/llvm/llvm-project/commit/6aaa8f25b66dc1fef4e465f274ee40b82d632988
DIFF: https://github.com/llvm/llvm-project/commit/6aaa8f25b66dc1fef4e465f274ee40b82d632988.diff
LOG: [mlir][IR][NFC] Move free-standing functions to `MemRefType` (#123465)
Turn free-standing `MemRefType`-related helper functions in
`BuiltinTypes.h` into member functions.
Added:
Modified:
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
mlir/include/mlir/IR/BuiltinTypes.h
mlir/include/mlir/IR/BuiltinTypes.td
mlir/include/mlir/IR/CommonTypeConstraints.td
mlir/lib/CAPI/IR/BuiltinTypes.cpp
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
mlir/lib/Conversion/LLVMCommon/Pattern.cpp
mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 5910aa3f7f2dae..f5cf3dad75d9c2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -198,7 +198,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
assert(memrefType && "Incorrect use of getStaticStrides");
- auto [strides, offset] = getStridesAndOffset(memrefType);
+ auto [strides, offset] = memrefType.getStridesAndOffset();
// reuse the storage of ConstStridesAttr since strides from
// memref is not persistant
setConstStrides(strides);
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 19c5361124aacb..df1e02732617d2 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -409,33 +409,6 @@ inline bool TensorType::classof(Type type) {
// Type Utilities
//===----------------------------------------------------------------------===//
-/// Returns the strides of the MemRef if the layout map is in strided form.
-/// MemRefs with a layout map in strided form include:
-/// 1. empty or identity layout map, in which case the stride information is
-/// the canonical form computed from sizes;
-/// 2. a StridedLayoutAttr layout;
-/// 3. any other layout that be converted into a single affine map layout of
-/// the form `K + k0 * d0 + ... kn * dn`, where K and ki's are constants or
-/// symbols.
-///
-/// A stride specification is a list of integer values that are either static
-/// or dynamic (encoded with ShapedType::kDynamic). Strides encode
-/// the distance in the number of elements between successive entries along a
-/// particular dimension.
-LogicalResult getStridesAndOffset(MemRefType t,
- SmallVectorImpl<int64_t> &strides,
- int64_t &offset);
-
-/// Wrapper around getStridesAndOffset(MemRefType, SmallVectorImpl<int64_t>,
-/// int64_t) that will assert if the logical result is not succeeded.
-std::pair<SmallVector<int64_t>, int64_t> getStridesAndOffset(MemRefType t);
-
-/// Return a version of `t` with identity layout if it can be determined
-/// statically that the layout is the canonical contiguous strided layout.
-/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
-/// `t` with simplified layout.
-MemRefType canonicalizeStridedLayout(MemRefType t);
-
/// Given MemRef `sizes` that are either static or dynamic, returns the
/// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
/// once a dynamic dimension is encountered, all canonical strides become
@@ -458,24 +431,6 @@ AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
/// where `exprs` is {d0, d1, .., d_(sizes.size()-1)}
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
MLIRContext *context);
-
-/// Return "true" if the layout for `t` is compatible with strided semantics.
-bool isStrided(MemRefType t);
-
-/// Return "true" if the last dimension of the given type has a static unit
-/// stride. Also return "true" for types with no strides.
-bool isLastMemrefDimUnitStride(MemRefType type);
-
-/// Return "true" if the last N dimensions of the given type are contiguous.
-///
-/// Examples:
-/// - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]> is contiguous when
-/// considering both _all_ and _only_ the trailing 3 dims,
-/// - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]> is _only_ contiguous when
-/// considering the trailing 3 dims.
-///
-bool trailingNDimsContiguous(MemRefType type, int64_t n);
-
} // namespace mlir
#endif // MLIR_IR_BUILTINTYPES_H
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 4f09d2e41e7ceb..e5a2ae81da0c9a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -808,10 +808,52 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
/// Arguments that are passed into the builder must outlive the builder.
class Builder;
+ /// Return "true" if the last N dimensions are contiguous.
+ ///
+ /// Examples:
+ /// - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]>> is contiguous when
+ /// considering both _all_ and _only_ the trailing 3 dims,
+ /// - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]>> is _only_ contiguous when
+ /// considering the trailing 3 dims.
+ ///
+ bool areTrailingDimsContiguous(int64_t n);
+
+ /// Return a version of this type with identity layout if it can be
+ /// determined statically that the layout is the canonical contiguous
+ /// strided layout. Otherwise pass the layout into `simplifyAffineMap`
+ /// and return a copy of this type with simplified layout.
+ MemRefType canonicalizeStridedLayout();
+
/// [deprecated] Returns the memory space in old raw integer representation.
/// New `Attribute getMemorySpace()` method should be used instead.
unsigned getMemorySpaceAsInt() const;
+ /// Returns the strides of the MemRef if the layout map is in strided form.
+ /// MemRefs with a layout map in strided form include:
+ /// 1. empty or identity layout map, in which case the stride information
+ /// is the canonical form computed from sizes;
+ /// 2. a StridedLayoutAttr layout;
+ /// 3. any other layout that can be converted into a single affine map
+ /// layout of the form `K + k0 * d0 + ... kn * dn`, where K and ki's are
+ /// constants or symbols.
+ ///
+ /// A stride specification is a list of integer values that are either
+ /// static or dynamic (encoded with ShapedType::kDynamic). Strides encode
+ /// the distance in the number of elements between successive entries along
+ /// a particular dimension.
+ LogicalResult getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
+ int64_t &offset);
+
+ /// Wrapper around getStridesAndOffset(SmallVectorImpl<int64_t>, int64_t)
+ /// that asserts that the operation succeeded.
+ std::pair<SmallVector<int64_t>, int64_t> getStridesAndOffset();
+
+ /// Return "true" if the layout is compatible with strided semantics.
+ bool isStrided();
+
+ /// Return "true" if the last dimension has a static unit stride. Also
+ /// return "true" for types with no strides.
+ bool isLastDimUnitStride();
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index e752cdfb47fbb1..5ec995b3ae9771 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -820,7 +820,7 @@ class StaticShapeMemRefOf<list<Type> allowedTypes> :
def AnyStaticShapeMemRef : StaticShapeMemRefOf<[AnyType]>;
// For a MemRefType, verify that it has strides.
-def HasStridesPred : CPred<[{ isStrided(::llvm::cast<::mlir::MemRefType>($_self)) }]>;
+def HasStridesPred : CPred<[{ ::llvm::cast<::mlir::MemRefType>($_self).isStrided() }]>;
class StridedMemRefOf<list<Type> allowedTypes> :
ConfinedType<MemRefOf<allowedTypes>, [HasStridesPred],
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 98ca9c3d239093..a080adf0f8103c 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -524,7 +524,7 @@ MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type,
int64_t *offset) {
MemRefType memrefType = llvm::cast<MemRefType>(unwrap(type));
SmallVector<int64_t> strides_;
- if (failed(getStridesAndOffset(memrefType, strides_, *offset)))
+ if (failed(memrefType.getStridesAndOffset(strides_, *offset)))
return mlirLogicalResultFailure();
(void)std::copy(strides_.begin(), strides_.end(), strides);
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 5d09d6f1d69523..51f5d7a161b903 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -192,7 +192,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
// Construct buffer descriptor from memref, attributes
int64_t offset = 0;
SmallVector<int64_t, 5> strides;
- if (failed(getStridesAndOffset(memrefType, strides, offset)))
+ if (failed(memrefType.getStridesAndOffset(strides, offset)))
return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");
MemRefDescriptor memrefDescriptor(memref);
diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
index 19c3ba1f950202..63f99eb744a83b 100644
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -52,7 +52,7 @@ MemRefDescriptor MemRefDescriptor::fromStaticShape(
assert(type.hasStaticShape() && "unexpected dynamic shape");
// Extract all strides and offsets and verify they are static.
- auto [strides, offset] = getStridesAndOffset(type);
+ auto [strides, offset] = type.getStridesAndOffset();
assert(!ShapedType::isDynamic(offset) && "expected static offset");
assert(!llvm::any_of(strides, ShapedType::isDynamic) &&
"expected static strides");
@@ -193,7 +193,7 @@ Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc,
MemRefType type) {
// When we convert to LLVM, the input memref must have been normalized
// beforehand. Hence, this call is guaranteed to work.
- auto [strides, offsetCst] = getStridesAndOffset(type);
+ auto [strides, offsetCst] = type.getStridesAndOffset();
Value ptr = alignedPtr(builder, loc);
// For zero offsets, we already have the base pointer.
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index d551506485a454..a47a2872ceb073 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -62,7 +62,7 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
ConversionPatternRewriter &rewriter) const {
- auto [strides, offset] = getStridesAndOffset(type);
+ auto [strides, offset] = type.getStridesAndOffset();
MemRefDescriptor memRefDescriptor(memRefDesc);
// Use a canonical representation of the start address so that later
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 247a8ab28a44be..ea251e4564ea8a 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -485,7 +485,7 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) const {
SmallVector<Type, 5>
LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
bool unpackAggregates) const {
- if (!isStrided(type)) {
+ if (!type.isStrided()) {
emitError(
UnknownLoc::get(type.getContext()),
"conversion to strided form failed either due to non-strided layout "
@@ -603,7 +603,7 @@ bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {
int64_t offset = 0;
SmallVector<int64_t, 4> strides;
- if (failed(getStridesAndOffset(memrefTy, strides, offset)))
+ if (failed(memrefTy.getStridesAndOffset(strides, offset)))
return false;
for (int64_t stride : strides)
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 86f687d7f2636e..f7542b8b3bc5c7 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1136,7 +1136,7 @@ struct MemRefReshapeOpLowering
// Extract the offset and strides from the type.
int64_t offset;
SmallVector<int64_t> strides;
- if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
+ if (failed(targetMemRefType.getStridesAndOffset(strides, offset)))
return rewriter.notifyMatchFailure(
reshapeOp, "failed to get stride and offset exprs");
@@ -1451,7 +1451,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
int64_t offset;
SmallVector<int64_t, 4> strides;
- auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
+ auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
if (failed(successStrides))
return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
assert(offset == 0 && "expected offset to be 0");
@@ -1560,7 +1560,7 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
auto memRefType = atomicOp.getMemRefType();
SmallVector<int64_t> strides;
int64_t offset;
- if (failed(getStridesAndOffset(memRefType, strides, offset)))
+ if (failed(memRefType.getStridesAndOffset(strides, offset)))
return failure();
auto dataPtr =
getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 5b4414d67fdac0..eaefe9e3857933 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -132,7 +132,7 @@ static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
return 0;
int64_t offset = 0;
SmallVector<int64_t, 2> strides;
- if (failed(getStridesAndOffset(memrefType, strides, offset)) ||
+ if (failed(memrefType.getStridesAndOffset(strides, offset)) ||
strides.back() != 1)
return std::nullopt;
int64_t stride = strides[strides.size() - 2];
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d688d8e2ab6588..a1e21cb524bd9a 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -91,7 +91,7 @@ LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
// Check if the last stride is non-unit and has a valid memory space.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
const LLVMTypeConverter &converter) {
- if (!isLastMemrefDimUnitStride(memRefType))
+ if (!memRefType.isLastDimUnitStride())
return failure();
if (failed(converter.getMemRefAddressSpace(memRefType)))
return failure();
@@ -1374,7 +1374,7 @@ static std::optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
int64_t offset;
SmallVector<int64_t, 4> strides;
- if (failed(getStridesAndOffset(memRefType, strides, offset)))
+ if (failed(memRefType.getStridesAndOffset(strides, offset)))
return std::nullopt;
if (!strides.empty() && strides.back() != 1)
return std::nullopt;
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 01bc65c841e94c..22bf27d229ce5d 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1650,7 +1650,7 @@ struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
return failure();
if (xferOp.getVectorType().getRank() != 1)
return failure();
- if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
+ if (map.isMinorIdentity() && memRefType.isLastDimUnitStride())
return failure(); // Handled by ConvertVectorToLLVM
// Loop bounds, step, state...
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 8041bdf7da19b3..d3229d2e912966 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -76,8 +76,7 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
// Validate further transfer op semantics.
SmallVector<int64_t> strides;
int64_t offset;
- if (failed(getStridesAndOffset(srcTy, strides, offset)) ||
- strides.back() != 1)
+ if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
return rewriter.notifyMatchFailure(
xferOp, "Buffer must be contiguous in the innermost dimension");
@@ -105,7 +104,7 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
xegpu::TensorDescType descType, TypedValue<MemRefType> src,
Operation::operand_range offsets) {
MemRefType srcTy = src.getType();
- auto [strides, offset] = getStridesAndOffset(srcTy);
+ auto [strides, offset] = srcTy.getStridesAndOffset();
xegpu::CreateNdDescOp ndDesc;
if (srcTy.hasStaticShape()) {
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 5af0cb0c7ba1cc..271ca382e2f0ba 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -129,7 +129,7 @@ static bool staticallyOutOfBounds(OpType op) {
return false;
int64_t offset;
SmallVector<int64_t> strides;
- if (failed(getStridesAndOffset(bufferType, strides, offset)))
+ if (failed(bufferType.getStridesAndOffset(strides, offset)))
return false;
int64_t result = offset + op.getIndexOffset().value_or(0);
if (op.getSgprOffset()) {
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 4eac371d4c1ae4..4cb777b03b1963 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -53,8 +53,7 @@ FailureOr<Value> getStride(ConversionPatternRewriter &rewriter,
unsigned bytes = width >> 3;
int64_t offset;
SmallVector<int64_t, 4> strides;
- if (failed(getStridesAndOffset(mType, strides, offset)) ||
- strides.back() != 1)
+ if (failed(mType.getStridesAndOffset(strides, offset)) || strides.back() != 1)
return failure();
if (strides[preLast] == ShapedType::kDynamic) {
// Dynamic stride needs code to compute the stride at runtime.
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index f1841b860ff81a..6be55a1d282240 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -42,8 +42,8 @@ FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
int64_t sourceOffset, targetOffset;
SmallVector<int64_t, 4> sourceStrides, targetStrides;
- if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
- failed(getStridesAndOffset(target, targetStrides, targetOffset)))
+ if (failed(source.getStridesAndOffset(sourceStrides, sourceOffset)) ||
+ failed(target.getStridesAndOffset(targetStrides, targetOffset)))
return false;
auto dynamicToStatic = [](int64_t a, int64_t b) {
return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index 2502744cb3f580..ce0f112dc2dd22 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -29,7 +29,7 @@ using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
static bool hasFullyDynamicLayoutMap(MemRefType type) {
int64_t offset;
SmallVector<int64_t, 4> strides;
- if (failed(getStridesAndOffset(type, strides, offset)))
+ if (failed(type.getStridesAndOffset(strides, offset)))
return false;
if (!llvm::all_of(strides, ShapedType::isDynamic))
return false;
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 49209229259a73..301066e7d3e1f8 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1903,7 +1903,7 @@ LogicalResult SubgroupMmaLoadMatrixOp::verify() {
auto operand = resMatrixType.getOperand();
auto srcMemrefType = llvm::cast<MemRefType>(srcType);
- if (!isLastMemrefDimUnitStride(srcMemrefType))
+ if (!srcMemrefType.isLastDimUnitStride())
return emitError(
"expected source memref most minor dim must have unit stride");
@@ -1923,7 +1923,7 @@ LogicalResult SubgroupMmaStoreMatrixOp::verify() {
auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
auto dstMemrefType = llvm::cast<MemRefType>(dstType);
- if (!isLastMemrefDimUnitStride(dstMemrefType))
+ if (!dstMemrefType.isLastDimUnitStride())
return emitError(
"expected destination memref most minor dim must have unit stride");
diff --git a/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp b/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp
index a504101fb3f2fc..2afdeff3a7be11 100644
--- a/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp
@@ -67,7 +67,7 @@ getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
}
- auto &&[sourceStrides, sourceOffset] = getStridesAndOffset(sourceType);
+ auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 9aae46a5c288dc..4f75b7618d6367 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -163,7 +163,7 @@ static SmallVector<int64_t> getConstantOffset(MemRefType memrefType) {
SmallVector<int64_t> strides;
int64_t offset;
LogicalResult hasStaticInformation =
- getStridesAndOffset(memrefType, strides, offset);
+ memrefType.getStridesAndOffset(strides, offset);
if (failed(hasStaticInformation))
return SmallVector<int64_t>();
return SmallVector<int64_t>(1, offset);
@@ -176,7 +176,7 @@ static SmallVector<int64_t> getConstantStrides(MemRefType memrefType) {
SmallVector<int64_t> strides;
int64_t offset;
LogicalResult hasStaticInformation =
- getStridesAndOffset(memrefType, strides, offset);
+ memrefType.getStridesAndOffset(strides, offset);
if (failed(hasStaticInformation))
return SmallVector<int64_t>();
return strides;
@@ -663,8 +663,8 @@ bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
// Only fold casts between strided memref forms.
int64_t sourceOffset, resultOffset;
SmallVector<int64_t, 4> sourceStrides, resultStrides;
- if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
- failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
+ if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
+ failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
return false;
// If cast is towards more static sizes along any dimension, don't fold.
@@ -708,8 +708,8 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (aT.getLayout() != bT.getLayout()) {
int64_t aOffset, bOffset;
SmallVector<int64_t, 4> aStrides, bStrides;
- if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
- failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
+ if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
+ failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
aStrides.size() != bStrides.size())
return false;
@@ -954,9 +954,9 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
SmallVector<int64_t> originalStrides, candidateStrides;
int64_t originalOffset, candidateOffset;
if (failed(
- getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
+ originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
failed(
- getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
+ reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
return failure();
// For memrefs, a dimension is truly dropped if its corresponding stride is
@@ -1903,7 +1903,7 @@ LogicalResult ReinterpretCastOp::verify() {
// identity layout.
int64_t resultOffset;
SmallVector<int64_t, 4> resultStrides;
- if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
+ if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
return emitError("expected result type to have strided layout but found ")
<< resultType;
@@ -2223,7 +2223,7 @@ computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
ArrayRef<ReassociationIndices> reassociation) {
int64_t srcOffset;
SmallVector<int64_t> srcStrides;
- if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
+ if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
return failure();
assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
@@ -2420,7 +2420,7 @@ computeCollapsedLayoutMap(MemRefType srcType,
int64_t srcOffset;
SmallVector<int64_t> srcStrides;
auto srcShape = srcType.getShape();
- if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
+ if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
return failure();
// The result stride of a reassociation group is the stride of the last entry
@@ -2706,7 +2706,7 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
assert(staticStrides.size() == rank && "staticStrides length mismatch");
// Extract source offset and strides.
- auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceMemRefType);
+ auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();
// Compute target offset whose value is:
// `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
@@ -2912,8 +2912,8 @@ Value SubViewOp::getViewSource() { return getSource(); }
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
int64_t t1Offset, t2Offset;
SmallVector<int64_t> t1Strides, t2Strides;
- auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
- auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
+ auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
+ auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
}
@@ -2928,8 +2928,8 @@ static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
"incorrect number of dropped dims");
int64_t t1Offset, t2Offset;
SmallVector<int64_t> t1Strides, t2Strides;
- auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
- auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
+ auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
+ auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
if (failed(res1) || failed(res2))
return false;
for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
@@ -2980,7 +2980,7 @@ LogicalResult SubViewOp::verify() {
<< baseType << " and subview memref type " << subViewType;
// Verify that the base memref type has a strided layout map.
- if (!isStrided(baseType))
+ if (!baseType.isStrided())
return emitError("base type ") << baseType << " is not strided";
// Compute the expected result type, assuming that there are no rank
@@ -3261,7 +3261,7 @@ struct SubViewReturnTypeCanonicalizer {
return nonReducedType;
// Take the strides and offset from the non-rank reduced type.
- auto [nonReducedStrides, offset] = getStridesAndOffset(nonReducedType);
+ auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();
// Drop dims from shape and strides.
SmallVector<int64_t> targetShape;
@@ -3341,7 +3341,7 @@ void TransposeOp::getAsmResultNames(
static MemRefType inferTransposeResultType(MemRefType memRefType,
AffineMap permutationMap) {
auto originalSizes = memRefType.getShape();
- auto [originalStrides, offset] = getStridesAndOffset(memRefType);
+ auto [originalStrides, offset] = memRefType.getStridesAndOffset();
assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
// Compute permuted sizes and strides.
@@ -3400,10 +3400,10 @@ LogicalResult TransposeOp::verify() {
auto srcType = llvm::cast<MemRefType>(getIn().getType());
auto resultType = llvm::cast<MemRefType>(getType());
- auto canonicalResultType = canonicalizeStridedLayout(
- inferTransposeResultType(srcType, getPermutation()));
+ auto canonicalResultType = inferTransposeResultType(srcType, getPermutation())
+ .canonicalizeStridedLayout();
- if (canonicalizeStridedLayout(resultType) != canonicalResultType)
+ if (resultType.canonicalizeStridedLayout() != canonicalResultType)
return emitOpError("result type ")
<< resultType
<< " is not equivalent to the canonical transposed input type "
@@ -3483,7 +3483,7 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
// Get offset from old memref view type 'memRefType'.
int64_t oldOffset;
SmallVector<int64_t, 4> oldStrides;
- if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
+ if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))
return failure();
assert(oldOffset == 0 && "Expected 0 offset");
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 28f9061d9873b7..f58385a7777dbc 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -632,7 +632,7 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
// Currently only handle innermost stride being 1, checking
SmallVector<int64_t> strides;
int64_t offset;
- if (failed(getStridesAndOffset(ty, strides, offset)))
+ if (failed(ty.getStridesAndOffset(strides, offset)))
return nullptr;
if (!strides.empty() && strides.back() != 1)
return nullptr;
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index aa008f8407b5d3..b69cbabe0dde97 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -68,9 +68,9 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter,
auto newExtractStridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
- auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceType);
+ auto [sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
#ifndef NDEBUG
- auto [resultStrides, resultOffset] = getStridesAndOffset(subview.getType());
+ auto [resultStrides, resultOffset] = subview.getType().getStridesAndOffset();
#endif // NDEBUG
// Compute the new strides and offset from the base strides and offset:
@@ -363,7 +363,7 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
// Collect the statically known information about the original stride.
Value source = expandShape.getSrc();
auto sourceType = cast<MemRefType>(source.getType());
- auto [strides, offset] = getStridesAndOffset(sourceType);
+ auto [strides, offset] = sourceType.getStridesAndOffset();
OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
? origStrides[groupId]
@@ -503,7 +503,7 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
Value source = collapseShape.getSrc();
auto sourceType = cast<MemRefType>(source.getType());
- auto [strides, offset] = getStridesAndOffset(sourceType);
+ auto [strides, offset] = sourceType.getStridesAndOffset();
SmallVector<OpFoldResult> groupStrides;
ArrayRef<int64_t> srcShape = sourceType.getShape();
@@ -528,7 +528,7 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
// but we still have to make the type system happy.
MemRefType collapsedType = collapseShape.getResultType();
auto [collapsedStrides, collapsedOffset] =
- getStridesAndOffset(collapsedType);
+ collapsedType.getStridesAndOffset();
int64_t finalStride = collapsedStrides[groupId];
if (ShapedType::isDynamic(finalStride)) {
// Look for a dynamic stride. At this point we don't know which one is
@@ -581,7 +581,7 @@ static FailureOr<StridedMetadata> resolveReshapeStridedMetadata(
rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
// Collect statically known information.
- auto [strides, offset] = getStridesAndOffset(sourceType);
+ auto [strides, offset] = sourceType.getStridesAndOffset();
MemRefType reshapeType = reshape.getResultType();
unsigned reshapeRank = reshapeType.getRank();
@@ -1068,7 +1068,7 @@ class ExtractStridedMetadataOpCastFolder
: ofr;
};
- auto [sourceStrides, sourceOffset] = getStridesAndOffset(memrefType);
+ auto [sourceStrides, sourceOffset] = memrefType.getStridesAndOffset();
assert(sourceStrides.size() == rank && "unexpected number of strides");
// Register the new offset.
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 450bfa0cec0c7f..f93ae0a7a298f0 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -91,7 +91,7 @@ struct CastOpInterface
// Get result offset and strides.
int64_t resultOffset;
SmallVector<int64_t> resultStrides;
- if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
+ if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
return;
// Check offset.
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index 6de744a7f75244..270b43100a3a74 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -27,7 +27,7 @@ bool isStaticShapeAndContiguousRowMajor(MemRefType type) {
SmallVector<int64_t> strides;
int64_t offset;
- if (failed(getStridesAndOffset(type, strides, offset)))
+ if (failed(type.getStridesAndOffset(strides, offset)))
return false;
// MemRef is contiguous if outer dimensions are size-1 and inner
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 47d1b8492e06ec..ba86e8d6ceaf92 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -70,9 +70,9 @@ LogicalResult DeviceAsyncCopyOp::verify() {
auto srcMemref = llvm::cast<MemRefType>(getSrc().getType());
auto dstMemref = llvm::cast<MemRefType>(getDst().getType());
- if (!isLastMemrefDimUnitStride(srcMemref))
+ if (!srcMemref.isLastDimUnitStride())
return emitError("source memref most minor dim must have unit stride");
- if (!isLastMemrefDimUnitStride(dstMemref))
+ if (!dstMemref.isLastDimUnitStride())
return emitError("destination memref most minor dim must have unit stride");
if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref))
return emitError()
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
index f8c699c65fe49e..10bc1993ffd960 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
@@ -24,8 +24,8 @@ template <typename OpTy>
static bool isContiguousXferOp(OpTy op) {
return op.getPermutationMap().isMinorIdentity() && op.isDimInBounds(0) &&
op.hasPureBufferSemantics() &&
- isLastMemrefDimUnitStride(
- cast<MemRefType>(nvgpu::getMemrefOperand(op).getType()));
+ cast<MemRefType>(nvgpu::getMemrefOperand(op).getType())
+ .isLastDimUnitStride();
}
/// Return "true" if the given op is a contiguous and suitable
diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
index c500815857ca5b..39cca7d363e0df 100644
--- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -296,7 +296,7 @@ bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferReadOp op) {
// Check that the last dimension of the read is contiguous. Note that it is
// possible to expand support for this by scalarizing all the loads during
// conversion.
- auto [strides, offset] = mlir::getStridesAndOffset(sourceType);
+ auto [strides, offset] = sourceType.getStridesAndOffset();
return strides.back() == 1;
}
@@ -320,6 +320,6 @@ bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferWriteOp op) {
// Check that the last dimension of the target memref is contiguous. Note that
// it is possible to expand support for this by scalarizing all the stores
// during conversion.
- auto [strides, offset] = mlir::getStridesAndOffset(sourceType);
+ auto [strides, offset] = sourceType.getStridesAndOffset();
return strides.back() == 1;
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 29f7e8afe0773b..c56dbcca2175d4 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -206,7 +206,7 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
int64_t offset;
SmallVector<int64_t, 4> strides;
if (!memRefType.hasStaticShape() ||
- failed(getStridesAndOffset(memRefType, strides, offset)))
+ failed(memRefType.getStridesAndOffset(strides, offset)))
return std::nullopt;
// To get the size of the memref object in memory, the total size is the
@@ -1225,7 +1225,7 @@ Value mlir::spirv::getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
int64_t offset;
SmallVector<int64_t, 4> strides;
- if (failed(getStridesAndOffset(baseType, strides, offset)) ||
+ if (failed(baseType.getStridesAndOffset(strides, offset)) ||
llvm::is_contained(strides, ShapedType::kDynamic) ||
ShapedType::isDynamic(offset)) {
return nullptr;
@@ -1256,7 +1256,7 @@ Value mlir::spirv::getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter,
int64_t offset;
SmallVector<int64_t, 4> strides;
- if (failed(getStridesAndOffset(baseType, strides, offset)) ||
+ if (failed(baseType.getStridesAndOffset(strides, offset)) ||
llvm::is_contained(strides, ShapedType::kDynamic) ||
ShapedType::isDynamic(offset)) {
return nullptr;
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 1abcacd6d6db3d..ed3ba321b37ab9 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -186,7 +186,7 @@ struct CollapseShapeOpInterface
// the source type.
SmallVector<int64_t> strides;
int64_t offset;
- if (failed(getStridesAndOffset(bufferType, strides, offset)))
+ if (failed(bufferType.getStridesAndOffset(strides, offset)))
return failure();
resultType = MemRefType::get(
{}, tensorResultType.getElementType(),
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 696d1e0f9b1e68..d8fc881911bae3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4974,7 +4974,7 @@ static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
(vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
return success();
- if (!isLastMemrefDimUnitStride(memRefTy))
+ if (!memRefTy.isLastDimUnitStride())
return op->emitOpError("most minor memref dim must have unit stride");
return success();
}
@@ -5789,7 +5789,7 @@ void TypeCastOp::build(OpBuilder &builder, OperationState &result,
}
LogicalResult TypeCastOp::verify() {
- MemRefType canonicalType = canonicalizeStridedLayout(getMemRefType());
+ MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
if (!canonicalType.getLayout().isIdentity())
return emitOpError("expects operand to be a memref with identity layout");
if (!getResultMemRefType().getLayout().isIdentity())
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index f9428a4ce28640..314dc44134e049 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -435,7 +435,7 @@ struct TransferReadToVectorLoadLowering
return rewriter.notifyMatchFailure(read, "not a memref source");
// Non-unit strides are handled by VectorToSCF.
- if (!isLastMemrefDimUnitStride(memRefType))
+ if (!memRefType.isLastDimUnitStride())
return rewriter.notifyMatchFailure(read, "!= 1 stride needs VectorToSCF");
// If there is broadcasting involved then we first load the unbroadcasted
@@ -588,7 +588,7 @@ struct TransferWriteToVectorStoreLowering
});
// Non-unit strides are handled by VectorToSCF.
- if (!isLastMemrefDimUnitStride(memRefType))
+ if (!memRefType.isLastDimUnitStride())
return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
diag << "most minor stride is not 1: " << write;
});
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index b0892d16969d29..5871d6dd5b3e6d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -267,7 +267,7 @@ static MemRefType dropUnitDims(MemRefType inputType,
auto targetShape = getReducedShape(sizes);
Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
targetShape, inputType, offsets, sizes, strides);
- return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
+ return cast<MemRefType>(rankReducedType).canonicalizeStridedLayout();
}
/// Creates a rank-reducing memref.subview op that drops unit dims from its
@@ -283,8 +283,8 @@ static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
rewriter.getIndexAttr(1));
MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
- if (canonicalizeStridedLayout(resultType) ==
- canonicalizeStridedLayout(inputType))
+ if (resultType.canonicalizeStridedLayout() ==
+ inputType.canonicalizeStridedLayout())
return input;
return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
sizes, strides);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index ee622e886f6185..66c23dd6e74950 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -145,8 +145,8 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
return MemRefType();
int64_t aOffset, bOffset;
SmallVector<int64_t, 4> aStrides, bStrides;
- if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
- failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
+ if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
+ failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
aStrides.size() != bStrides.size())
return MemRefType();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 21ec718efd6a7a..84c1deaebcd009 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1243,7 +1243,7 @@ static FailureOr<size_t>
getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
SmallVector<int64_t> srcStrides;
int64_t srcOffset;
- if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
+ if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
return failure();
auto isUnitDim = [](VectorType type, int dim) {
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index e590d8c43c44b0..7b56cd0cf0e912 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -261,7 +261,7 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
ArrayRef<int64_t> vectorShape = vectorType.getShape();
auto vecRank = vectorType.getRank();
- if (!trailingNDimsContiguous(memrefType, vecRank))
+ if (!memrefType.areTrailingDimsContiguous(vecRank))
return false;
// Extract the trailing dims and strides of the input memref
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index bd1163bddf7ee0..3924d082f06280 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -645,24 +645,74 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
-//===----------------------------------------------------------------------===//
-// UnrankedMemRefType
-//===----------------------------------------------------------------------===//
+bool MemRefType::areTrailingDimsContiguous(int64_t n) {
+ if (!isLastDimUnitStride())
+ return false;
-unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
- return detail::getMemorySpaceAsInt(getMemorySpace());
+ auto memrefShape = getShape().take_back(n);
+ if (ShapedType::isDynamicShape(memrefShape))
+ return false;
+
+ if (getLayout().isIdentity())
+ return true;
+
+ int64_t offset;
+ SmallVector<int64_t> stridesFull;
+ if (!succeeded(getStridesAndOffset(stridesFull, offset)))
+ return false;
+ auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
+
+ if (strides.empty())
+ return true;
+
+ // Check whether strides match "flattened" dims.
+ SmallVector<int64_t> flattenedDims;
+ auto dimProduct = 1;
+ for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
+ dimProduct *= dim;
+ flattenedDims.push_back(dimProduct);
+ }
+
+ strides = strides.drop_back(1);
+ return llvm::equal(strides, llvm::reverse(flattenedDims));
}
-LogicalResult
-UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
- Type elementType, Attribute memorySpace) {
- if (!BaseMemRefType::isValidElementType(elementType))
- return emitError() << "invalid memref element type";
+MemRefType MemRefType::canonicalizeStridedLayout() {
+ AffineMap m = getLayout().getAffineMap();
- if (!isSupportedMemorySpace(memorySpace))
- return emitError() << "unsupported memory space Attribute";
+ // Already in canonical form.
+ if (m.isIdentity())
+ return *this;
- return success();
+ // Can't reduce to canonical identity form, return in canonical form.
+ if (m.getNumResults() > 1)
+ return *this;
+
+ // Corner-case for 0-D affine maps.
+ if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
+ if (auto cst = llvm::dyn_cast<AffineConstantExpr>(m.getResult(0)))
+ if (cst.getValue() == 0)
+ return MemRefType::Builder(*this).setLayout({});
+ return *this;
+ }
+
+ // 0-D corner case for empty shape that still have an affine map. Example:
+ // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
+ // offset needs to remain, just return t.
+ if (getShape().empty())
+ return *this;
+
+ // If the canonical strided layout for the sizes of `t` is equal to the
+ // simplified layout of `t` we can just return an empty layout. Otherwise,
+ // just simplify the existing layout.
+ AffineExpr expr = makeCanonicalStridedLayoutExpr(getShape(), getContext());
+ auto simplifiedLayoutExpr =
+ simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
+ if (expr != simplifiedLayoutExpr)
+ return MemRefType::Builder(*this).setLayout(
+ AffineMapAttr::get(AffineMap::get(m.getNumDims(), m.getNumSymbols(),
+ simplifiedLayoutExpr)));
+ return MemRefType::Builder(*this).setLayout({});
}
// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
@@ -783,11 +833,10 @@ static LogicalResult getStridesAndOffset(MemRefType t,
return success();
}
-LogicalResult mlir::getStridesAndOffset(MemRefType t,
- SmallVectorImpl<int64_t> &strides,
- int64_t &offset) {
+LogicalResult MemRefType::getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
+ int64_t &offset) {
// Happy path: the type uses the strided layout directly.
- if (auto strided = llvm::dyn_cast<StridedLayoutAttr>(t.getLayout())) {
+ if (auto strided = llvm::dyn_cast<StridedLayoutAttr>(getLayout())) {
llvm::append_range(strides, strided.getStrides());
offset = strided.getOffset();
return success();
@@ -797,14 +846,14 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
// convertible to affine maps.
AffineExpr offsetExpr;
SmallVector<AffineExpr, 4> strideExprs;
- if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
+ if (failed(::getStridesAndOffset(*this, strideExprs, offsetExpr)))
return failure();
- if (auto cst = dyn_cast<AffineConstantExpr>(offsetExpr))
+ if (auto cst = llvm::dyn_cast<AffineConstantExpr>(offsetExpr))
offset = cst.getValue();
else
offset = ShapedType::kDynamic;
for (auto e : strideExprs) {
- if (auto c = dyn_cast<AffineConstantExpr>(e))
+ if (auto c = llvm::dyn_cast<AffineConstantExpr>(e))
strides.push_back(c.getValue());
else
strides.push_back(ShapedType::kDynamic);
@@ -812,16 +861,49 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
return success();
}
-std::pair<SmallVector<int64_t>, int64_t>
-mlir::getStridesAndOffset(MemRefType t) {
+std::pair<SmallVector<int64_t>, int64_t> MemRefType::getStridesAndOffset() {
SmallVector<int64_t> strides;
int64_t offset;
- LogicalResult status = getStridesAndOffset(t, strides, offset);
+ LogicalResult status = getStridesAndOffset(strides, offset);
(void)status;
assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset");
return {strides, offset};
}
+bool MemRefType::isStrided() {
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ auto res = getStridesAndOffset(strides, offset);
+ return succeeded(res);
+}
+
+bool MemRefType::isLastDimUnitStride() {
+ int64_t offset;
+ SmallVector<int64_t> strides;
+ auto successStrides = getStridesAndOffset(strides, offset);
+ return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
+}
+
+//===----------------------------------------------------------------------===//
+// UnrankedMemRefType
+//===----------------------------------------------------------------------===//
+
+unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
+ return detail::getMemorySpaceAsInt(getMemorySpace());
+}
+
+LogicalResult
+UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
+ Type elementType, Attribute memorySpace) {
+ if (!BaseMemRefType::isValidElementType(elementType))
+ return emitError() << "invalid memref element type";
+
+ if (!isSupportedMemorySpace(memorySpace))
+ return emitError() << "unsupported memory space Attribute";
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
/// TupleType
//===----------------------------------------------------------------------===//
@@ -849,49 +931,6 @@ size_t TupleType::size() const { return getImpl()->size(); }
// Type Utilities
//===----------------------------------------------------------------------===//
-/// Return a version of `t` with identity layout if it can be determined
-/// statically that the layout is the canonical contiguous strided layout.
-/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
-/// `t` with simplified layout.
-/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
-MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
- AffineMap m = t.getLayout().getAffineMap();
-
- // Already in canonical form.
- if (m.isIdentity())
- return t;
-
- // Can't reduce to canonical identity form, return in canonical form.
- if (m.getNumResults() > 1)
- return t;
-
- // Corner-case for 0-D affine maps.
- if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
- if (auto cst = dyn_cast<AffineConstantExpr>(m.getResult(0)))
- if (cst.getValue() == 0)
- return MemRefType::Builder(t).setLayout({});
- return t;
- }
-
- // 0-D corner case for empty shape that still have an affine map. Example:
- // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
- // offset needs to remain, just return t.
- if (t.getShape().empty())
- return t;
-
- // If the canonical strided layout for the sizes of `t` is equal to the
- // simplified layout of `t` we can just return an empty layout. Otherwise,
- // just simplify the existing layout.
- AffineExpr expr =
- makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
- auto simplifiedLayoutExpr =
- simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
- if (expr != simplifiedLayoutExpr)
- return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
- m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
- return MemRefType::Builder(t).setLayout({});
-}
-
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
ArrayRef<AffineExpr> exprs,
MLIRContext *context) {
@@ -932,49 +971,3 @@ AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
exprs.push_back(getAffineDimExpr(dim, context));
return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
}
-
-bool mlir::isStrided(MemRefType t) {
- int64_t offset;
- SmallVector<int64_t, 4> strides;
- auto res = getStridesAndOffset(t, strides, offset);
- return succeeded(res);
-}
-
-bool mlir::isLastMemrefDimUnitStride(MemRefType type) {
- int64_t offset;
- SmallVector<int64_t> strides;
- auto successStrides = getStridesAndOffset(type, strides, offset);
- return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
-}
-
-bool mlir::trailingNDimsContiguous(MemRefType type, int64_t n) {
- if (!isLastMemrefDimUnitStride(type))
- return false;
-
- auto memrefShape = type.getShape().take_back(n);
- if (ShapedType::isDynamicShape(memrefShape))
- return false;
-
- if (type.getLayout().isIdentity())
- return true;
-
- int64_t offset;
- SmallVector<int64_t> stridesFull;
- if (!succeeded(getStridesAndOffset(type, stridesFull, offset)))
- return false;
- auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
-
- if (strides.empty())
- return true;
-
- // Check whether strides match "flattened" dims.
- SmallVector<int64_t> flattenedDims;
- auto dimProduct = 1;
- for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
- dimProduct *= dim;
- flattenedDims.push_back(dimProduct);
- }
-
- strides = strides.drop_back(1);
- return llvm::equal(strides, llvm::reverse(flattenedDims));
-}
diff --git a/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp b/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp
index 968e10b8d0cab6..f17f5db2fa22fe 100644
--- a/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp
+++ b/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp
@@ -35,7 +35,7 @@ void TestMemRefStrideCalculation::runOnOperation() {
auto memrefType = cast<MemRefType>(allocOp.getResult().getType());
int64_t offset;
SmallVector<int64_t, 4> strides;
- if (failed(getStridesAndOffset(memrefType, strides, offset))) {
+ if (failed(memrefType.getStridesAndOffset(strides, offset))) {
llvm::outs() << "MemRefType " << memrefType << " cannot be converted to "
<< "strided form\n";
return;
More information about the Mlir-commits mailing list