[Mlir-commits] [mlir] 3c9aa10 - Foo
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Aug 4 04:25:18 PDT 2023
Author: Nicolas Vasilache
Date: 2023-08-04T11:06:17Z
New Revision: 3c9aa10c57cf0833ff108ecf9ffbb512bd96cc89
URL: https://github.com/llvm/llvm-project/commit/3c9aa10c57cf0833ff108ecf9ffbb512bd96cc89
DIFF: https://github.com/llvm/llvm-project/commit/3c9aa10c57cf0833ff108ecf9ffbb512bd96cc89.diff
LOG: Foo
Added:
Modified:
mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h
mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
mlir/lib/Conversion/LLVMCommon/Pattern.cpp
mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h b/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h
index bc45fcd84e0982..b6e10b29fbb20c 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h
@@ -14,6 +14,7 @@
#ifndef MLIR_CONVERSION_LLVMCOMMON_LOWERINGOPTIONS_H
#define MLIR_CONVERSION_LLVMCOMMON_LOWERINGOPTIONS_H
+#include "mlir/IR/BuiltinTypes.h"
#include "llvm/IR/DataLayout.h"
namespace mlir {
@@ -66,6 +67,9 @@ class LowerToLLVMOptions {
/// Get the index bitwidth.
unsigned getIndexBitwidth() const { return indexBitwidth; }
+ /// Non-owning hook picking the index type for a MemRefType; null = default.
+ llvm::function_ref<Type(MemRefType)> memrefIndexTypeConverter = nullptr;
+
private:
unsigned indexBitwidth;
};
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 0aee13818df4d5..52bba8b0e97ef5 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -50,6 +50,11 @@ class ConvertToLLVMPattern : public ConversionPattern {
/// defined by the used type converter.
Type getIndexType() const;
+ /// Gets the MLIR type wrapping the LLVM integer type used to index into
+ /// MemRefType `t`, as determined by the type converter (see
+ /// LLVMTypeConverter::getIndexTypeMatchingMemRef).
+ Type getIndexTypeMatchingMemRef(MemRefType t) const;
+
/// Gets the MLIR type wrapping the LLVM integer type whose bit width
/// corresponds to that of a LLVM pointer type.
Type getIntPtrType(unsigned addressSpace = 0) const;
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index 79a68e875f045e..970fde18ad7bad 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -132,14 +132,20 @@ class LLVMTypeConverter : public TypeConverter {
/// integer type with the size configured for this type converter.
Type getIndexType();
- /// Returns true if using opaque pointers was enabled in the lowering options.
+ /// Gets the LLVM representation of the index type that matches the MemRefType
+ /// `t`: the type returned by the `memrefIndexTypeConverter` lowering option
+ /// if set, otherwise the converter's default index type.
+ Type getIndexTypeMatchingMemRef(MemRefType t);
+
+ /// Returns true if using opaque pointers was enabled in the lowering
+ /// options.
bool useOpaquePointers() const { return getOptions().useOpaquePointers; }
/// Creates an LLVM pointer type with the given element type and address
/// space.
- /// This function is meant to be used in code supporting both typed and opaque
- /// pointers, as it will create an opaque pointer with the given address space
- /// if opaque pointers are enabled in the lowering options.
+ /// This function is meant to be used in code supporting both typed and
+ /// opaque pointers, as it will create an opaque pointer with the given
+ /// address space if opaque pointers are enabled in the lowering options.
LLVM::LLVMPointerType getPointerType(Type elementType,
unsigned addressSpace = 0);
@@ -170,13 +176,13 @@ class LLVMTypeConverter : public TypeConverter {
private:
/// Convert a function type. The arguments and results are converted one by
- /// one. Additionally, if the function returns more than one value, pack the
- /// results into an LLVM IR structure type so that the converted function type
- /// returns at most one result.
+ /// one. Additionally, if the function returns more than one value, pack
+ /// the results into an LLVM IR structure type so that the converted
+ /// function type returns at most one result.
Type convertFunctionType(FunctionType type);
- /// Convert the index type. Uses llvmModule data layout to create an integer
- /// of the pointer bitwidth.
+ /// Convert the index type. Uses llvmModule data layout to create an
+ /// integer of the pointer bitwidth.
Type convertIndexType(IndexType type);
/// Convert an integer type `i*` to `!llvm<"i*">`.
@@ -184,12 +190,13 @@ class LLVMTypeConverter : public TypeConverter {
/// Convert a floating point type: `f16` to `f16`, `f32` to
/// `f32` and `f64` to `f64`. `bf16` is not supported
- /// by LLVM. 8-bit float types are converted to 8-bit integers as this is how
- /// all LLVM backends that support them currently represent them.
+ /// by LLVM. 8-bit float types are converted to 8-bit integers as this is
+ /// how all LLVM backends that support them currently represent them.
Type convertFloatType(FloatType type);
- /// Convert complex number type: `complex<f16>` to `!llvm<"{ half, half }">`,
- /// `complex<f32>` to `!llvm<"{ float, float }">`, and `complex<f64>` to
+ /// Convert complex number type: `complex<f16>` to
+ /// `!llvm<"{ half, half }">`, `complex<f32>` to `!llvm<"{ float, float }">`,
+ /// and `complex<f64>` to
/// `!llvm<"{ double, double }">`. `complex<bf16>` is not supported.
Type convertComplexType(ComplexType type);
@@ -197,10 +204,10 @@ class LLVMTypeConverter : public TypeConverter {
Type convertMemRefType(MemRefType type);
/// Convert a memref type into a list of LLVM IR types that will form the
- /// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides`
- /// arrays in the descriptors are unpacked to individual index-typed elements,
- /// else they are are kept as rank-sized arrays of index type. In particular,
- /// the list will contain:
+ /// memref descriptor. If `unpackAggregates` is true the `sizes` and
+ /// `strides` arrays in the descriptors are unpacked to individual
+ /// index-typed elements, else they are kept as rank-sized arrays of
+ /// index type. In particular, the list will contain:
/// - two pointers to the memref element type, followed by
/// - an index-typed offset, followed by
/// - (if unpackAggregates = true)
@@ -220,9 +227,9 @@ class LLVMTypeConverter : public TypeConverter {
SmallVector<Type, 5> getMemRefDescriptorFields(MemRefType type,
bool unpackAggregates);
- /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
- /// that will form the unranked memref descriptor. In particular, this list
- /// contains:
+ /// Convert an unranked memref type into a list of non-aggregate LLVM IR
+ /// types that will form the unranked memref descriptor. In particular, this
+ /// list contains:
/// - an integer rank, followed by
/// - a pointer to the memref descriptor struct.
/// For example, memref<*xf32> is converted to the following list:
diff --git a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
index 495c4d63986f80..b35b77ccee6371 100644
--- a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
+++ b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
@@ -43,7 +43,8 @@ struct AllocationOpLLVMLowering : public ConvertToLLVMPattern {
MemRefType memRefType = op.getType();
Value alignment;
if (auto alignmentAttr = op.getAlignment()) {
- Type indexType = getIndexType();
+ Type indexType =
+ ConvertToLLVMPattern::getIndexTypeMatchingMemRef(memRefType);
alignment =
createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
} else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
index 842e63aff7bd9f..f779c0cb4592f4 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
@@ -52,6 +52,14 @@ def GPU_Dialect : Dialect {
/// Returns the numeric value used to identify the private memory address
/// space.
static AddressSpace getPrivateAddressSpace() { return AddressSpace::Private; }
+
+ /// Return true if the given MemRefType has an address space that is a
+ /// gpu::AddressSpaceAttr attribute with value `workgroup`.
+ static bool hasSharedMemoryAddressSpace(MemRefType type);
+
+ /// Return true if the given Attribute is a gpu::AddressSpaceAttr attribute
+ /// with value `workgroup`.
+ static bool isSharedMemoryAddressSpace(Attribute type);
}];
let dependentDialects = ["arith::ArithDialect"];
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 9993c093badc16..0a197b9ab01004 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -67,7 +67,7 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
protected:
Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
MemRefType type, MemRefDescriptor desc) const {
- Type indexType = ConvertToLLVMPattern::getIndexType();
+ Type indexType = ConvertToLLVMPattern::getIndexTypeMatchingMemRef(type);
return type.hasStaticShape()
? ConvertToLLVMPattern::createIndexAttrConstant(
rewriter, loc, indexType, type.getNumElements())
@@ -654,10 +654,16 @@ class ConvertSDDMMOpToGpuRuntimeCallPattern
} // namespace
+static IntegerType getIndexTypeForMemRef(MemRefType t) {
+ int64_t numBits = gpu::GPUDialect::hasSharedMemoryAddressSpace(t) ? 32 : 64;
+ return IntegerType::get(t.getContext(), numBits);
+}
+
void GpuToLLVMConversionPass::runOnOperation() {
LowerToLLVMOptions options(&getContext());
options.useOpaquePointers = useOpaquePointers;
options.useBarePtrCallConv = hostBarePtrCallConv;
+ options.memrefIndexTypeConverter = getIndexTypeForMemRef;
LLVMTypeConverter converter(&getContext(), options);
RewritePatternSet patterns(&getContext());
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 1699172eb9dab3..72c6c93734fd10 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -19,6 +19,20 @@ using namespace mlir;
// ConvertToLLVMPattern
//===----------------------------------------------------------------------===//
+static Value convertToDesiredIndexType(OpBuilder &b, Location loc, Value src,
+ Type desiredIndexType) {
+ assert(src.getType().isIntOrIndex() && !src.getType().isIndex() &&
+ "expected int type");
+ assert(desiredIndexType.isIntOrIndex() && !desiredIndexType.isIndex() &&
+ "expected int type");
+ if (src.getType() == desiredIndexType)
+ return src;
+ if (src.getType().getIntOrFloatBitWidth() <
+ desiredIndexType.getIntOrFloatBitWidth())
+ return b.create<LLVM::SExtOp>(loc, desiredIndexType, src);
+ return b.create<LLVM::TruncOp>(loc, desiredIndexType, src);
+}
+
ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
MLIRContext *context,
LLVMTypeConverter &typeConverter,
@@ -38,6 +52,10 @@ Type ConvertToLLVMPattern::getIndexType() const {
return getTypeConverter()->getIndexType();
}
+Type ConvertToLLVMPattern::getIndexTypeMatchingMemRef(MemRefType t) const {
+ return getTypeConverter()->getIndexTypeMatchingMemRef(t);
+}
+
Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
return IntegerType::get(&getTypeConverter()->getContext(),
getTypeConverter()->getPointerBitwidth(addressSpace));
@@ -74,7 +92,7 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
Value base =
memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), type);
- Type indexType = getIndexType();
+ Type indexType = getIndexTypeMatchingMemRef(type);
Value index;
for (int i = 0, e = indices.size(); i < e; ++i) {
Value increment = indices[i];
@@ -83,8 +101,11 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
ShapedType::isDynamic(strides[i])
? memRefDescriptor.stride(rewriter, loc, i)
: createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
+ increment =
+ convertToDesiredIndexType(rewriter, loc, increment, indexType);
increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
}
+ increment = convertToDesiredIndexType(rewriter, loc, increment, indexType);
index =
index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
}
@@ -127,7 +148,7 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
sizes.reserve(memRefType.getRank());
unsigned dynamicIndex = 0;
- Type indexType = getIndexType();
+ Type indexType = getIndexTypeMatchingMemRef(memRefType);
for (int64_t size : memRefType.getShape()) {
sizes.push_back(
size == ShapedType::kDynamic
@@ -194,7 +215,7 @@ Value ConvertToLLVMPattern::getNumElements(
static_cast<ssize_t>(dynamicSizes.size()) &&
"dynamicSizes size doesn't match dynamic sizes count in memref shape");
- Type indexType = getIndexType();
+ Type indexType = getIndexTypeMatchingMemRef(memRefType);
Value numElements = memRefType.getRank() == 0
? createIndexAttrConstant(rewriter, loc, indexType, 1)
: nullptr;
@@ -233,7 +254,7 @@ MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
// Field 3: Offset in aligned pointer.
- Type indexType = getIndexType();
+ Type indexType = getIndexTypeMatchingMemRef(memRefType);
memRefDescriptor.setOffset(
rewriter, loc, createIndexAttrConstant(rewriter, loc, indexType, 0));
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 9e03e2ffbacf83..c63be57dd850ff 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -174,6 +174,11 @@ Type LLVMTypeConverter::getIndexType() {
return IntegerType::get(&getContext(), getIndexTypeBitwidth());
}
+Type LLVMTypeConverter::getIndexTypeMatchingMemRef(MemRefType t) {
+ return options.memrefIndexTypeConverter ? options.memrefIndexTypeConverter(t)
+ : getIndexType();
+}
+
LLVM::LLVMPointerType
LLVMTypeConverter::getPointerType(Type elementType, unsigned int addressSpace) {
if (useOpaquePointers())
@@ -339,7 +344,7 @@ LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
}
auto ptrTy = getPointerType(elementType, *addressSpace);
- auto indexTy = getIndexType();
+ Type indexTy = getIndexTypeMatchingMemRef(type);
SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
auto rank = type.getRank();
@@ -358,7 +363,8 @@ unsigned LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type,
// Compute the descriptor size given that of its components indicated above.
unsigned space = *getMemRefAddressSpace(type);
return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) +
- (1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType());
+ (1 + 2 * type.getRank()) *
+ layout.getTypeSize(getIndexTypeMatchingMemRef(type));
}
/// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 55372b9c9c1248..1c2cd86c53706f 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -160,7 +160,7 @@ struct ReallocOpLoweringBase : public AllocationOpLLVMLowering {
auto computeNumElements =
[&](MemRefType type, function_ref<Value()> getDynamicSize) -> Value {
// Compute number of elements.
- Type indexType = ConvertToLLVMPattern::getIndexType();
+ Type indexType = ConvertToLLVMPattern::getIndexTypeMatchingMemRef(type);
Value numElements =
type.isDynamicDim(0)
? getDynamicSize()
@@ -483,7 +483,8 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
// The size value that we have to extract can be obtained using GEPop with
// `dimOp.index() + 1` index argument.
Value idxPlusOne = rewriter.create<LLVM::AddOp>(
- loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
+ loc,
+ createIndexAttrConstant(rewriter, loc, adaptor.getIndex().getType(), 1),
adaptor.getIndex());
Value sizePtr = rewriter.create<LLVM::GEPOp>(
loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
@@ -510,7 +511,7 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
// Take advantage if index is constant.
MemRefType memRefType = cast<MemRefType>(operandType);
- Type indexType = getIndexType();
+ Type indexType = getIndexTypeMatchingMemRef(memRefType);
if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
int64_t i = *index;
if (i >= 0 && i < memRefType.getRank()) {
@@ -1360,7 +1361,7 @@ struct MemRefReshapeOpLowering
assert(targetMemRefType.getLayout().isIdentity() &&
"Identity layout map is a precondition of a valid reshape op");
- Type indexType = getIndexType();
+ Type indexType = getIndexTypeMatchingMemRef(targetMemRefType);
Value stride = nullptr;
int64_t targetRank = targetMemRefType.getRank();
for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
@@ -1455,7 +1456,8 @@ struct MemRefReshapeOpLowering
Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
- Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
+ Value oneIndex =
+ createIndexAttrConstant(rewriter, loc, resultRank.getType(), 1);
Value resultRankMinusOne =
rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
@@ -1708,7 +1710,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
- Type indexType = getIndexType();
+ auto indexType = targetMemRef.getIndexType();
// Field 3: The offset in the resulting type must be 0. This is
// because of the type change: an offset on srcType* may not be
// expressible as an offset on dstType*.
@@ -1865,7 +1867,7 @@ class ExtractStridedMetadataOpLowering
} // namespace
-static void populateModuleIndependentFinalizeMemRefToLLVMConversionPatterns(
+void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
// clang-format off
patterns.add<
@@ -1881,6 +1883,7 @@ static void populateModuleIndependentFinalizeMemRefToLLVMConversionPatterns(
GetGlobalMemrefOpLowering,
LoadOpLowering,
MemRefCastOpLowering,
+ MemRefCopyOpLowering,
MemorySpaceCastOpLowering,
MemRefReinterpretCastOpLowering,
MemRefReshapeOpLowering,
@@ -1893,15 +1896,6 @@ static void populateModuleIndependentFinalizeMemRefToLLVMConversionPatterns(
TransposeOpLowering,
ViewOpLowering>(converter);
// clang-format on
-}
-
-void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
- LLVMTypeConverter &converter, RewritePatternSet &patterns) {
- // clang-format off
- patterns.add<
- MemRefCopyOpLowering>(converter);
- // clang-format on
-
auto allocLowering = converter.getOptions().allocLowering;
if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
patterns.add<AlignedAllocOpLowering, AlignedReallocOpLowering,
@@ -1909,9 +1903,6 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
patterns.add<AllocOpLowering, ReallocOpLowering, DeallocOpLowering>(
converter);
-
- populateModuleIndependentFinalizeMemRefToLLVMConversionPatterns(converter,
- patterns);
}
namespace {
@@ -1940,12 +1931,6 @@ struct FinalizeMemRefToLLVMConversionPass
&dataLayoutAnalysis);
RewritePatternSet patterns(&getContext());
populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
- if (isa<ModuleOp>(getOperation())) {
- populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
- } else {
- populateModuleIndependentFinalizeMemRefToLLVMConversionPatterns(
- typeConverter, patterns);
- }
LLVMConversionTarget target(getContext());
target.addLegalOp<func::FuncOp>();
if (failed(applyPartialConversion(op, target, std::move(patterns))))
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index c9f378c181e36d..cec2830e10b932 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -35,6 +35,22 @@ using namespace mlir::gpu;
#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
+/// Return true if the given MemRefType has an address space that is a
+/// gpu::AddressSpaceAttr attribute with value `workgroup`.
+bool gpu::GPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
+ return isSharedMemoryAddressSpace(type.getMemorySpace());
+}
+
+/// Return true if the given Attribute is a gpu::AddressSpaceAttr attribute
+/// with value `workgroup`; a null attribute yields false.
+bool gpu::GPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
+ if (!memorySpace)
+ return false;
+ if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+ return gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
+ return false;
+}
+
//===----------------------------------------------------------------------===//
// GPU Device Mapping Attributes
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list