[Mlir-commits] [mlir] 6ca1a09 - [mlir][gpu] Migrate hard-coded address space integers to an enum attribute (gpu::AddressSpaceAttr)
Christopher Bate
llvmlistbot at llvm.org
Fri Jan 13 10:00:15 PST 2023
Author: Christopher Bate
Date: 2023-01-13T11:00:10-07:00
New Revision: 6ca1a09f03e8e940f306bea73efa935e4ee38173
URL: https://github.com/llvm/llvm-project/commit/6ca1a09f03e8e940f306bea73efa935e4ee38173
DIFF: https://github.com/llvm/llvm-project/commit/6ca1a09f03e8e940f306bea73efa935e4ee38173.diff
LOG: [mlir][gpu] Migrate hard-coded address space integers to an enum attribute (gpu::AddressSpaceAttr)
This is a purely mechanical change that introduces an enum attribute in the GPU
dialect (`gpu::AddressSpaceAttr`) to represent the various memref memory spaces,
replacing the hard-coded integer attributes that were used previously.
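For illustration (mirroring the updated tests), a workgroup attribution that was
previously written with a magic address-space number, e.g.

  gpu.func @kernel(%arg0: f32) workgroup(%arg1: memref<4xf32, 3>) {
    gpu.return
  }

is now written with the symbolic enum attribute:

  gpu.func @kernel(%arg0: f32) workgroup(%arg1: memref<4xf32, #gpu.address_space<workgroup>>) {
    gpu.return
  }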
The following steps were taken to make the transition across the codebase:
1. Introduce a pass "gpu-lower-memory-space-attributes":
The pass updates all memref types whose memory space attribute is a
`gpu::AddressSpaceAttr`. These attributes are changed to `IntegerAttr`s using a
mapping given by the caller (see the example after this list). The pass is based on the
"map-memref-spirv-storage-class" pass, and the common functions can probably
be refactored into a set of utilities under the MemRef dialect.
2. Update the verifiers of GPU/NVGPU dialect operations.
If a verifier currently checks the address space of an operand using
e.g. `getWorkgroupAddressSpace`, it can continue to do so. However, the
checks are changed to fail only if the memory space is missing or is a
`gpu::AddressSpaceAttr` with the wrong value. Otherwise, the verifier assumes the
address space is correct, because it was deliberately lowered to something other
than a `gpu::AddressSpaceAttr`.
3. Update existing gpu-to-llvm conversion infrastructure.
In the existing gpu-to-X passes, we add a full conversion equivalent to
`gpu-lower-memory-space-attributes` just before the conversion to the
LLVM dialect. This is needed because both gpu-to-llvm passes
(rocdl, nvvm) currently run gpu-to-gpu rewrites within the pass, and those
rewrites introduce `AddressSpaceAttr` memory space annotations. Therefore, I
inserted the memory space conversion between the gpu-to-gpu rewrites and the
LLVM conversion.
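As a concrete example (taken from the new lower-memory-space-attrs.mlir test),
running `mlir-opt -gpu-lower-memory-space-attributes` with the default mapping
(global=1, workgroup=3, private=5) rewrites

  gpu.func @workgroup(%arg0: f32) workgroup(%arg1: memref<4xf32, #gpu.address_space<workgroup>>) {
    %c0 = arith.constant 0 : index
    memref.store %arg0, %arg1[%c0] : memref<4xf32, #gpu.address_space<workgroup>>
    gpu.return
  }

into

  gpu.func @workgroup(%arg0: f32) workgroup(%arg1: memref<4xf32, 3>) {
    %c0 = arith.constant 0 : index
    memref.store %arg0, %arg1[%c0] : memref<4xf32, 3>
    gpu.return
  }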
For more context, see the Discourse discussion:
https://discourse.llvm.org/t/gpu-workgroup-shared-memory-address-space-is-hard-coded/
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D140644
Added:
mlir/lib/Dialect/GPU/Transforms/LowerMemorySpaceAttributes.cpp
mlir/test/Dialect/GPU/lower-memory-space-attrs.mlir
Modified:
mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
mlir/lib/Dialect/GPU/CMakeLists.txt
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp
mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
mlir/test/Conversion/GPUCommon/memory-attrbution.mlir
mlir/test/Dialect/GPU/all-reduce-max.mlir
mlir/test/Dialect/GPU/all-reduce.mlir
mlir/test/Dialect/GPU/invalid.mlir
mlir/test/Dialect/GPU/promotion.mlir
mlir/test/Dialect/NVGPU/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
index 7b02ed7c707e..af8bcb88dd08 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
@@ -14,6 +14,7 @@
#define GPU_BASE
include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
//===----------------------------------------------------------------------===//
@@ -46,11 +47,11 @@ def GPU_Dialect : Dialect {
/// Returns the numeric value used to identify the workgroup memory address
/// space.
- static unsigned getWorkgroupAddressSpace() { return 3; }
+ static AddressSpace getWorkgroupAddressSpace() { return AddressSpace::Workgroup; }
/// Returns the numeric value used to identify the private memory address
/// space.
- static unsigned getPrivateAddressSpace() { return 5; }
+ static AddressSpace getPrivateAddressSpace() { return AddressSpace::Private; }
}];
let dependentDialects = ["arith::ArithDialect"];
@@ -59,6 +60,37 @@ def GPU_Dialect : Dialect {
let useFoldAPI = kEmitFoldAdaptorFolder;
}
+//===----------------------------------------------------------------------===//
+// GPU Enums.
+//===----------------------------------------------------------------------===//
+
+class GPU_I32Enum<string name, string description, list<I32EnumAttrCase> cases>
+ : I32EnumAttr<name, description, cases> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::gpu";
+}
+class GPU_I32EnumAttr<string mnemonic, GPU_I32Enum enumInfo> :
+ EnumAttr<GPU_Dialect, enumInfo, mnemonic> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def GPU_AddressSpaceGlobal : I32EnumAttrCase<"Global", 1, "global">;
+def GPU_AddressSpaceWorkgroup : I32EnumAttrCase<"Workgroup", 2, "workgroup">;
+def GPU_AddressSpacePrivate : I32EnumAttrCase<"Private", 3, "private">;
+def GPU_AddressSpaceEnum : GPU_I32Enum<
+ "AddressSpace", "GPU address space", [
+ GPU_AddressSpaceGlobal,
+ GPU_AddressSpaceWorkgroup,
+ GPU_AddressSpacePrivate
+ ]>;
+
+def GPU_AddressSpaceAttr :
+ GPU_I32EnumAttr<"address_space", GPU_AddressSpaceEnum>;
+
+//===----------------------------------------------------------------------===//
+// GPU Types.
+//===----------------------------------------------------------------------===//
+
def GPU_AsyncToken : DialectType<
GPU_Dialect, CPred<"$_self.isa<::mlir::gpu::AsyncTokenType>()">, "async token type">,
BuildableType<"mlir::gpu::AsyncTokenType::get($_builder.getContext())">;
@@ -77,6 +109,10 @@ class MMAMatrixOf<list<Type> allowedTypes> :
"$_self.cast<::mlir::gpu::MMAMatrixType>().getElementType()",
"gpu.mma_matrix", "::mlir::gpu::MMAMatrixType">;
+//===----------------------------------------------------------------------===//
+// GPU Interfaces.
+//===----------------------------------------------------------------------===//
+
def GPU_AsyncOpInterface : OpInterface<"AsyncOpInterface"> {
let description = [{
Interface for GPU operations that execute asynchronously on the device.
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 0d6e97fe230c..9ab830b0a067 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -23,6 +23,8 @@ class Module;
} // namespace llvm
namespace mlir {
+class TypeConverter;
+class ConversionTarget;
namespace func {
class FuncOp;
} // namespace func
@@ -58,6 +60,23 @@ inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
}
namespace gpu {
+/// A function that maps a MemorySpace enum to a target-specific integer value.
+using MemorySpaceMapping =
+ std::function<unsigned(gpu::AddressSpace gpuAddressSpace)>;
+
+/// Populates type conversion rules for lowering memory space attributes to
+/// numeric values.
+void populateMemorySpaceAttributeTypeConversions(
+ TypeConverter &typeConverter, const MemorySpaceMapping &mapping);
+
+/// Populates patterns to lower memory space attributes to numeric values.
+void populateMemorySpaceLoweringPatterns(TypeConverter &typeConverter,
+ RewritePatternSet &patterns);
+
+/// Populates legality rules for lowering memory space attributes to numeric
+/// values.
+void populateLowerMemorySpaceOpLegality(ConversionTarget &target);
+
/// Returns the default annotation name for GPU binary blobs.
std::string getDefaultGpuBinaryAnnotation();
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
index a144fa4127dd..fae2f0f37fc9 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
@@ -37,4 +37,23 @@ def GpuMapParallelLoopsPass
let dependentDialects = ["mlir::gpu::GPUDialect"];
}
+def GPULowerMemorySpaceAttributesPass
+ : Pass<"gpu-lower-memory-space-attributes"> {
+ let summary = "Assign numeric values to memref memory space symbolic placeholders";
+ let description = [{
+ Updates all memref types that have a memory space attribute
+ that is a `gpu::AddressSpaceAttr`. These attributes are
+ changed to `IntegerAttr`'s using a mapping that is given in the
+ options.
+ }];
+ let options = [
+ Option<"privateAddrSpace", "private", "unsigned", "5",
+ "private address space numeric value">,
+ Option<"workgroupAddrSpace", "workgroup", "unsigned", "3",
+ "workgroup address space numeric value">,
+ Option<"globalAddrSpace", "global", "unsigned", "1",
+ "global address space numeric value">
+ ];
+}
+
#endif // MLIR_DIALECT_GPU_PASSES
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index a39c41d4a530..862a8e100455 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -36,6 +36,13 @@ def ROCDL_Dialect : Dialect {
static constexpr ::llvm::StringLiteral getReqdWorkGroupSizeAttrName() {
return ::llvm::StringLiteral("rocdl.reqd_work_group_size");
}
+
+ /// The address space value that represents global memory.
+ static constexpr unsigned kGlobalMemoryAddressSpace = 1;
+ /// The address space value that represents shared memory.
+ static constexpr unsigned kSharedMemoryAddressSpace = 3;
+ /// The address space value that represents private memory.
+ static constexpr unsigned kPrivateMemoryAddressSpace = 5;
}];
}
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 4d90a0fcfa28..4d9db9ae39e5 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -37,6 +37,23 @@ def NVGPU_Dialect : Dialect {
let useDefaultTypePrinterParser = 1;
let useFoldAPI = kEmitFoldAdaptorFolder;
+
+ let extraClassDeclaration = [{
+ /// Return true if the given MemRefType has an integer address
+ /// space that matches the NVVM shared memory address space or
+ /// is a gpu::AddressSpaceAttr attribute with value `workgroup`.
+ static bool hasSharedMemoryAddressSpace(MemRefType type);
+
+ /// Defines the MemRef memory space attribute numeric value that indicates
+ /// a memref is located in global memory. This should correspond to the
+ /// value used in NVVM.
+ static constexpr unsigned kGlobaldMemoryAddressSpace = 1;
+
+ /// Defines the MemRef memory space attribute numeric value that indicates
+ /// a memref is located in shared memory. This should correspond to the
+ /// value used in NVVM.
+ static constexpr unsigned kSharedMemoryAddressSpace = 3;
+ }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 668b4431147f..48c0cbf37988 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -38,7 +38,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
auto globalOp = rewriter.create<LLVM::GlobalOp>(
gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
LLVM::Linkage::Internal, name, /*value=*/Attribute(),
- /*alignment=*/0, gpu::GPUDialect::getWorkgroupAddressSpace());
+ /*alignment=*/0, workgroupAddrSpace);
workgroupBuffers.push_back(globalOp);
}
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index d3ee0fd93442..55efb2230ad9 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -16,9 +16,10 @@ namespace mlir {
struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
GPUFuncOpLowering(LLVMTypeConverter &converter, unsigned allocaAddrSpace,
- StringAttr kernelAttributeName)
+ unsigned workgroupAddrSpace, StringAttr kernelAttributeName)
: ConvertOpToLLVMPattern<gpu::GPUFuncOp>(converter),
allocaAddrSpace(allocaAddrSpace),
+ workgroupAddrSpace(workgroupAddrSpace),
kernelAttributeName(kernelAttributeName) {}
LogicalResult
@@ -26,8 +27,10 @@ struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
ConversionPatternRewriter &rewriter) const override;
private:
- /// The address spcae to use for `alloca`s in private memory.
+ /// The address space to use for `alloca`s in private memory.
unsigned allocaAddrSpace;
+ /// The address space to use for declaring workgroup memory.
+ unsigned workgroupAddrSpace;
/// The attribute name to use instead of `gpu.kernel`.
StringAttr kernelAttributeName;
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index fe578f7560ac..f3cf780306e7 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -175,31 +175,52 @@ struct LowerGpuOpsToNVVMOpsPass
if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
options.overrideIndexBitwidth(indexBitwidth);
- // MemRef conversion for GPU to NVVM lowering. The GPU dialect uses memory
- // space 5 for private memory attributions, but NVVM represents private
- // memory allocations as local `alloca`s in the default address space. This
- // converter drops the private memory space to support the use case above.
+ // Apply in-dialect lowering. In-dialect lowering will replace
+ // ops which need to be lowered further, which is not supported by a
+ // single conversion pass.
+ {
+ RewritePatternSet patterns(m.getContext());
+ populateGpuRewritePatterns(patterns);
+ if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns))))
+ return signalPassFailure();
+ }
+
+ // MemRef conversion for GPU to NVVM lowering.
+ {
+ RewritePatternSet patterns(m.getContext());
+ TypeConverter typeConverter;
+ typeConverter.addConversion([](Type t) { return t; });
+ // NVVM uses alloca in the default address space to represent private
+ // memory allocations, so drop private annotations. NVVM uses address
+ // space 3 for shared memory. NVVM uses the default address space to
+ // represent global memory.
+ gpu::populateMemorySpaceAttributeTypeConversions(
+ typeConverter, [](gpu::AddressSpace space) -> unsigned {
+ switch (space) {
+ case gpu::AddressSpace::Global:
+ return static_cast<unsigned>(
+ NVVM::NVVMMemorySpace::kGlobalMemorySpace);
+ case gpu::AddressSpace::Workgroup:
+ return static_cast<unsigned>(
+ NVVM::NVVMMemorySpace::kSharedMemorySpace);
+ case gpu::AddressSpace::Private:
+ return 0;
+ }
+ });
+ gpu::populateMemorySpaceLoweringPatterns(typeConverter, patterns);
+ ConversionTarget target(getContext());
+ gpu::populateLowerMemorySpaceOpLegality(target);
+ if (failed(applyFullConversion(m, target, std::move(patterns))))
+ return signalPassFailure();
+ }
+
LLVMTypeConverter converter(m.getContext(), options);
- converter.addConversion([&](MemRefType type) -> std::optional<Type> {
- if (type.getMemorySpaceAsInt() !=
- gpu::GPUDialect::getPrivateAddressSpace())
- return std::nullopt;
- return converter.convertType(MemRefType::Builder(type).setMemorySpace(
- IntegerAttr::get(IntegerType::get(m.getContext(), 64), 0)));
- });
// Lowering for MMAMatrixType.
converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
return convertMMAToLLVMType(type);
});
- RewritePatternSet patterns(m.getContext());
RewritePatternSet llvmPatterns(m.getContext());
- // Apply in-dialect lowering first. In-dialect lowering will replace ops
- // which need to be lowered further, which is not supported by a single
- // conversion pass.
- populateGpuRewritePatterns(patterns);
- (void)applyPatternsAndFoldGreedily(m, std::move(patterns));
-
arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
@@ -257,6 +278,8 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
// memory space and does not support `alloca`s with addrspace(5).
patterns.add<GPUFuncOpLowering>(
converter, /*allocaAddrSpace=*/0,
+ /*workgroupAddrSpace=*/
+ static_cast<unsigned>(NVVM::NVVMMemorySpace::kSharedMemorySpace),
StringAttr::get(&converter.getContext(),
NVVM::NVVMDialect::getKernelFuncAttrName()));
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 018288646a14..25571bd0ddb0 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -121,13 +121,41 @@ struct LowerGpuOpsToROCDLOpsPass
}
}
- LLVMTypeConverter converter(ctx, options);
+ // Apply in-dialect lowering. In-dialect lowering will replace
+ // ops which need to be lowered further, which is not supported by a
+ // single conversion pass.
+ {
+ RewritePatternSet patterns(ctx);
+ populateGpuRewritePatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(m, std::move(patterns));
+ }
- RewritePatternSet patterns(ctx);
- RewritePatternSet llvmPatterns(ctx);
+ // Apply memory space lowering. The target uses 3 for workgroup memory and 5
+ // for private memory.
+ {
+ RewritePatternSet patterns(ctx);
+ TypeConverter typeConverter;
+ typeConverter.addConversion([](Type t) { return t; });
+ gpu::populateMemorySpaceAttributeTypeConversions(
+ typeConverter, [](gpu::AddressSpace space) {
+ switch (space) {
+ case gpu::AddressSpace::Global:
+ return 1;
+ case gpu::AddressSpace::Workgroup:
+ return 3;
+ case gpu::AddressSpace::Private:
+ return 5;
+ }
+ });
+ ConversionTarget target(getContext());
+ gpu::populateLowerMemorySpaceOpLegality(target);
+ gpu::populateMemorySpaceLoweringPatterns(typeConverter, patterns);
+ if (failed(applyFullConversion(m, target, std::move(patterns))))
+ return signalPassFailure();
+ }
- populateGpuRewritePatterns(patterns);
- (void)applyPatternsAndFoldGreedily(m, std::move(patterns));
+ LLVMTypeConverter converter(ctx, options);
+ RewritePatternSet llvmPatterns(ctx);
mlir::arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
@@ -208,7 +236,9 @@ void mlir::populateGpuToROCDLConversionPatterns(
ROCDL::GridDimYOp, ROCDL::GridDimZOp>,
GPUReturnOpLowering>(converter);
patterns.add<GPUFuncOpLowering>(
- converter, /*allocaAddrSpace=*/5,
+ converter,
+ /*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
+ /*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
StringAttr::get(&converter.getContext(),
ROCDL::ROCDLDialect::getKernelFuncAttrName()));
if (Runtime::HIP == runtime) {
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
index 884b5ea316bd..3f92c6f8e55a 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
@@ -222,7 +222,9 @@ static bool isLegalAttr(Attribute attr) {
static bool isLegalOp(Operation *op) {
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
return llvm::all_of(funcOp.getArgumentTypes(), isLegalType) &&
- llvm::all_of(funcOp.getResultTypes(), isLegalType);
+ llvm::all_of(funcOp.getResultTypes(), isLegalType) &&
+ llvm::all_of(funcOp.getFunctionBody().getArgumentTypes(),
+ isLegalType);
}
auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) {
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index 94f3ab505f23..a38695878c10 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -52,6 +52,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
Transforms/SerializeToBlob.cpp
Transforms/SerializeToCubin.cpp
Transforms/SerializeToHsaco.cpp
+ Transforms/LowerMemorySpaceAttributes.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 031e0d92153e..f9018275f6a4 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1015,16 +1015,22 @@ LogicalResult GPUFuncOp::verifyType() {
static LogicalResult verifyAttributions(Operation *op,
ArrayRef<BlockArgument> attributions,
- unsigned memorySpace) {
+ gpu::AddressSpace memorySpace) {
for (Value v : attributions) {
auto type = v.getType().dyn_cast<MemRefType>();
if (!type)
return op->emitOpError() << "expected memref type in attribution";
- if (type.getMemorySpaceAsInt() != memorySpace) {
+ // We can only verify the address space if it hasn't already been lowered
+ // from the AddressSpaceAttr to a target-specific numeric value.
+ auto addressSpace =
+ type.getMemorySpace().dyn_cast_or_null<gpu::AddressSpaceAttr>();
+ if (!addressSpace)
+ continue;
+ if (addressSpace.getValue() != memorySpace)
return op->emitOpError()
- << "expected memory space " << memorySpace << " in attribution";
- }
+ << "expected memory space " << stringifyAddressSpace(memorySpace)
+ << " in attribution";
}
return success();
}
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index 894573b919ef..0a584a7920e0 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -158,8 +158,8 @@ struct GpuAllReduceRewriter {
/// Adds type to funcOp's workgroup attributions.
Value createWorkgroupBuffer() {
// TODO: Pick a proper location for the attribution.
- int workgroupMemoryAddressSpace =
- gpu::GPUDialect::getWorkgroupAddressSpace();
+ auto workgroupMemoryAddressSpace = gpu::AddressSpaceAttr::get(
+ funcOp->getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
auto bufferType = MemRefType::get({kSubgroupSize}, valueType, AffineMap{},
workgroupMemoryAddressSpace);
return funcOp.addWorkgroupAttribution(bufferType, rewriter.getUnknownLoc());
@@ -404,7 +404,7 @@ struct GpuAllReduceConversion : public RewritePattern {
return WalkResult::advance();
};
- if (funcOp.walk(callback).wasInterrupted())
+ if (funcOp.walk(callback).wasInterrupted() || reduceOps.empty())
return rewriter.notifyMatchFailure(
op, "Non uniform reductions are not supported yet.");
diff --git a/mlir/lib/Dialect/GPU/Transforms/LowerMemorySpaceAttributes.cpp b/mlir/lib/Dialect/GPU/Transforms/LowerMemorySpaceAttributes.cpp
new file mode 100644
index 000000000000..bb5fb06f7fb9
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/LowerMemorySpaceAttributes.cpp
@@ -0,0 +1,182 @@
+//===- LowerMemorySpaceAttributes.cpp ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// Implementation of a pass that rewrites the IR so that uses of
+/// `gpu::AddressSpaceAttr` in memref memory space annotations are replaced
+/// with caller-specified numeric values.
+///
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_GPULOWERMEMORYSPACEATTRIBUTESPASS
+#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::gpu;
+
+//===----------------------------------------------------------------------===//
+// Conversion Target
+//===----------------------------------------------------------------------===//
+
+/// Returns true if the given `type` is considered as legal during memory space
+/// attribute lowering.
+static bool isLegalType(Type type) {
+ if (auto memRefType = type.dyn_cast<BaseMemRefType>()) {
+ return !memRefType.getMemorySpace()
+ .isa_and_nonnull<gpu::AddressSpaceAttr>();
+ }
+ return true;
+}
+
+/// Returns true if the given `attr` is considered legal during memory space
+/// attribute lowering.
+static bool isLegalAttr(Attribute attr) {
+ if (auto typeAttr = attr.dyn_cast<TypeAttr>())
+ return isLegalType(typeAttr.getValue());
+ return true;
+}
+
+/// Returns true if the given `op` is legal during memory space attribute
+/// lowering.
+static bool isLegalOp(Operation *op) {
+ if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
+ return llvm::all_of(funcOp.getArgumentTypes(), isLegalType) &&
+ llvm::all_of(funcOp.getResultTypes(), isLegalType) &&
+ llvm::all_of(funcOp.getFunctionBody().getArgumentTypes(),
+ isLegalType);
+ }
+
+ auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) {
+ return attr.getValue();
+ });
+
+ return llvm::all_of(op->getOperandTypes(), isLegalType) &&
+ llvm::all_of(op->getResultTypes(), isLegalType) &&
+ llvm::all_of(attrs, isLegalAttr);
+}
+
+void gpu::populateLowerMemorySpaceOpLegality(ConversionTarget &target) {
+ target.markUnknownOpDynamicallyLegal(isLegalOp);
+}
+
+//===----------------------------------------------------------------------===//
+// Type Converter
+//===----------------------------------------------------------------------===//
+
+IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
+ return IntegerAttr::get(IntegerType::get(ctx, 64), space);
+}
+
+void mlir::gpu::populateMemorySpaceAttributeTypeConversions(
+ TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
+ typeConverter.addConversion([mapping](Type type) -> Optional<Type> {
+ auto subElementType = type.dyn_cast_or_null<SubElementTypeInterface>();
+ if (!subElementType)
+ return type;
+ Type newType = subElementType.replaceSubElements(
+ [mapping](Attribute attr) -> std::optional<Attribute> {
+ auto memorySpaceAttr = attr.dyn_cast_or_null<gpu::AddressSpaceAttr>();
+ if (!memorySpaceAttr)
+ return std::nullopt;
+ auto newValue = wrapNumericMemorySpace(
+ attr.getContext(), mapping(memorySpaceAttr.getValue()));
+ return newValue;
+ });
+ return newType;
+ });
+}
+
+namespace {
+
+/// Converts any op that has operands/results/attributes with numeric MemRef
+/// memory spaces.
+struct LowerMemRefAddressSpacePattern final : public ConversionPattern {
+ LowerMemRefAddressSpacePattern(MLIRContext *context, TypeConverter &converter)
+ : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ SmallVector<NamedAttribute> newAttrs;
+ newAttrs.reserve(op->getAttrs().size());
+ for (auto attr : op->getAttrs()) {
+ if (auto typeAttr = attr.getValue().dyn_cast<TypeAttr>()) {
+ auto newAttr = getTypeConverter()->convertType(typeAttr.getValue());
+ newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
+ } else {
+ newAttrs.push_back(attr);
+ }
+ }
+
+ SmallVector<Type> newResults;
+ (void)getTypeConverter()->convertTypes(op->getResultTypes(), newResults);
+
+ OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
+ newResults, newAttrs, op->getSuccessors());
+
+ for (Region ®ion : op->getRegions()) {
+ Region *newRegion = state.addRegion();
+ rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
+ TypeConverter::SignatureConversion result(newRegion->getNumArguments());
+ (void)getTypeConverter()->convertSignatureArgs(
+ newRegion->getArgumentTypes(), result);
+ rewriter.applySignatureConversion(newRegion, result);
+ }
+
+ Operation *newOp = rewriter.create(state);
+ rewriter.replaceOp(op, newOp->getResults());
+ return success();
+ }
+};
+} // namespace
+
+void mlir::gpu::populateMemorySpaceLoweringPatterns(
+ TypeConverter &typeConverter, RewritePatternSet &patterns) {
+ patterns.add<LowerMemRefAddressSpacePattern>(patterns.getContext(),
+ typeConverter);
+}
+
+namespace {
+class LowerMemorySpaceAttributesPass
+ : public mlir::impl::GPULowerMemorySpaceAttributesPassBase<
+ LowerMemorySpaceAttributesPass> {
+public:
+ using Base::Base;
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+ Operation *op = getOperation();
+
+ ConversionTarget target(getContext());
+ populateLowerMemorySpaceOpLegality(target);
+
+ TypeConverter typeConverter;
+ typeConverter.addConversion([](Type t) { return t; });
+ populateMemorySpaceAttributeTypeConversions(
+ typeConverter, [this](AddressSpace space) -> unsigned {
+ switch (space) {
+ case AddressSpace::Global:
+ return globalAddrSpace;
+ case AddressSpace::Workgroup:
+ return workgroupAddrSpace;
+ case AddressSpace::Private:
+ return privateAddrSpace;
+ }
+ });
+ RewritePatternSet patterns(context);
+ populateMemorySpaceLoweringPatterns(typeConverter, patterns);
+ if (failed(applyFullConversion(op, target, std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+} // namespace
diff --git a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp
index 49b77009caa4..5672b02b0226 100644
--- a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp
@@ -147,9 +147,11 @@ void mlir::promoteToWorkgroupMemory(GPUFuncOp op, unsigned arg) {
assert(type && type.hasStaticShape() && "can only promote memrefs");
// Get the type of the buffer in the workgroup memory.
- int workgroupMemoryAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace();
- auto bufferType = MemRefType::get(type.getShape(), type.getElementType(), {},
- workgroupMemoryAddressSpace);
+ auto workgroupMemoryAddressSpace = gpu::AddressSpaceAttr::get(
+ op->getContext(), gpu::AddressSpace::Workgroup);
+ auto bufferType = MemRefType::get(type.getShape(), type.getElementType(),
+ MemRefLayoutAttrInterface{},
+ Attribute(workgroupMemoryAddressSpace));
Value attribution = op.addWorkgroupAttribution(bufferType, value.getLoc());
// Replace the uses first since only the original uses are currently present.
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 24f70cb986e2..99623cdc2b34 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -34,6 +34,17 @@ void nvgpu::NVGPUDialect::initialize() {
>();
}
+bool nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
+ Attribute memorySpace = type.getMemorySpace();
+ if (!memorySpace)
+ return false;
+ if (auto intAttr = memorySpace.dyn_cast<IntegerAttr>())
+ return intAttr.getInt() == NVGPUDialect::kSharedMemoryAddressSpace;
+ if (auto gpuAttr = memorySpace.dyn_cast<gpu::AddressSpaceAttr>())
+ return gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
+ return false;
+}
+
//===----------------------------------------------------------------------===//
// NVGPU_DeviceAsyncCopyOp
//===----------------------------------------------------------------------===//
@@ -52,14 +63,17 @@ static bool isLastMemrefDimUnitStride(MemRefType type) {
LogicalResult DeviceAsyncCopyOp::verify() {
auto srcMemref = getSrc().getType().cast<MemRefType>();
auto dstMemref = getDst().getType().cast<MemRefType>();
- unsigned workgroupAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace();
+
if (!isLastMemrefDimUnitStride(srcMemref))
return emitError("source memref most minor dim must have unit stride");
if (!isLastMemrefDimUnitStride(dstMemref))
return emitError("destination memref most minor dim must have unit stride");
- if (dstMemref.getMemorySpaceAsInt() != workgroupAddressSpace)
- return emitError("destination memref must have memory space ")
- << workgroupAddressSpace;
+ if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref))
+ return emitError()
+ << "destination memref must have a memory space attribute of "
+ "IntegerAttr("
+ << NVGPUDialect::kSharedMemoryAddressSpace
+ << ") or gpu::AddressSpaceAttr(Workgroup)";
if (dstMemref.getElementType() != srcMemref.getElementType())
return emitError("source and destination must have the same element type");
if (size_t(srcMemref.getRank()) != getSrcIndices().size())
@@ -248,17 +262,16 @@ LogicalResult LdMatrixOp::verify() {
// transpose elements in vector registers at 16b granularity when true
bool isTranspose = getTranspose();
- // address space id for shared memory
- unsigned smemAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace();
-
//
// verification
//
- if (!(srcMemref.getMemorySpaceAsInt() == smemAddressSpace))
+ if (!NVGPUDialect::hasSharedMemoryAddressSpace(srcMemref))
return emitError()
- << "expected nvgpu.ldmatrix srcMemref must have memory space "
- << smemAddressSpace;
+ << "expected nvgpu.ldmatrix srcMemref must have a memory space "
+ "attribute of IntegerAttr("
+ << NVGPUDialect::kSharedMemoryAddressSpace
+ << ") or gpu::AddressSpaceAttr(Workgroup)";
if (elementBitWidth > 32)
return emitError() << "nvgpu.ldmatrix works for 32b or lower";
if (isTranspose && !(elementBitWidth == 16))
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
index 5d511e83cf16..07e9ae9f8650 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
@@ -181,8 +181,7 @@ mlir::LogicalResult
mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
Value memrefValue) {
auto memRefType = memrefValue.getType().dyn_cast<MemRefType>();
- if (!memRefType || memRefType.getMemorySpaceAsInt() !=
- gpu::GPUDialect::getWorkgroupAddressSpace())
+ if (!memRefType || !NVGPUDialect::hasSharedMemoryAddressSpace(memRefType))
return failure();
// Abort if the given value has any sub-views; we do not do any alias
@@ -258,11 +257,7 @@ class OptimizeSharedMemoryPass
Operation *op = getOperation();
SmallVector<memref::AllocOp> shmAllocOps;
op->walk([&](memref::AllocOp allocOp) {
- if (allocOp.getMemref()
- .getType()
- .cast<MemRefType>()
- .getMemorySpaceAsInt() !=
- gpu::GPUDialect::getWorkgroupAddressSpace())
+ if (!NVGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
return;
shmAllocOps.push_back(allocOp);
});
diff --git a/mlir/test/Conversion/GPUCommon/memory-attrbution.mlir b/mlir/test/Conversion/GPUCommon/memory-attrbution.mlir
index 9c94e9c14dea..7d563a545d45 100644
--- a/mlir/test/Conversion/GPUCommon/memory-attrbution.mlir
+++ b/mlir/test/Conversion/GPUCommon/memory-attrbution.mlir
@@ -3,7 +3,7 @@
gpu.module @kernel {
// NVVM-LABEL: llvm.func @private
- gpu.func @private(%arg0: f32) private(%arg1: memref<4xf32, 5>) {
+ gpu.func @private(%arg0: f32) private(%arg1: memref<4xf32, #gpu.address_space<private>>) {
// Allocate private memory inside the function.
// NVVM: %[[size:.*]] = llvm.mlir.constant(4 : i64) : i64
// NVVM: %[[raw:.*]] = llvm.alloca %[[size]] x f32 : (i64) -> !llvm.ptr<f32>
@@ -42,7 +42,7 @@ gpu.module @kernel {
// ROCDL: llvm.getelementptr
// ROCDL: llvm.store
%c0 = arith.constant 0 : index
- memref.store %arg0, %arg1[%c0] : memref<4xf32, 5>
+ memref.store %arg0, %arg1[%c0] : memref<4xf32, #gpu.address_space<private>>
"terminator"() : () -> ()
}
@@ -65,7 +65,7 @@ gpu.module @kernel {
// ROCDL-LABEL: llvm.func @workgroup
// ROCDL-SAME: {
- gpu.func @workgroup(%arg0: f32) workgroup(%arg1: memref<4xf32, 3>) {
+ gpu.func @workgroup(%arg0: f32) workgroup(%arg1: memref<4xf32, #gpu.address_space<workgroup>>) {
// Get the address of the first element in the global array.
// NVVM: %[[addr:.*]] = llvm.mlir.addressof @[[$buffer]] : !llvm.ptr<array<4 x f32>, 3>
// NVVM: %[[raw:.*]] = llvm.getelementptr %[[addr]][0, 0]
@@ -106,7 +106,7 @@ gpu.module @kernel {
// ROCDL: llvm.getelementptr
// ROCDL: llvm.store
%c0 = arith.constant 0 : index
- memref.store %arg0, %arg1[%c0] : memref<4xf32, 3>
+ memref.store %arg0, %arg1[%c0] : memref<4xf32, #gpu.address_space<workgroup>>
"terminator"() : () -> ()
}
@@ -126,7 +126,7 @@ gpu.module @kernel {
// NVVM-LABEL: llvm.func @workgroup3d
// ROCDL-LABEL: llvm.func @workgroup3d
- gpu.func @workgroup3d(%arg0: f32) workgroup(%arg1: memref<4x2x6xf32, 3>) {
+ gpu.func @workgroup3d(%arg0: f32) workgroup(%arg1: memref<4x2x6xf32, #gpu.address_space<workgroup>>) {
// Get the address of the first element in the global array.
// NVVM: %[[addr:.*]] = llvm.mlir.addressof @[[$buffer]] : !llvm.ptr<array<48 x f32>, 3>
// NVVM: %[[raw:.*]] = llvm.getelementptr %[[addr]][0, 0]
@@ -174,7 +174,7 @@ gpu.module @kernel {
// ROCDL: %[[descr10:.*]] = llvm.insertvalue %[[c1]], %[[descr9]][4, 2]
%c0 = arith.constant 0 : index
- memref.store %arg0, %arg1[%c0,%c0,%c0] : memref<4x2x6xf32, 3>
+ memref.store %arg0, %arg1[%c0,%c0,%c0] : memref<4x2x6xf32, #gpu.address_space<workgroup>>
"terminator"() : () -> ()
}
}
@@ -196,8 +196,8 @@ gpu.module @kernel {
// NVVM-LABEL: llvm.func @multiple
// ROCDL-LABEL: llvm.func @multiple
gpu.func @multiple(%arg0: f32)
- workgroup(%arg1: memref<1xf32, 3>, %arg2: memref<2xf32, 3>)
- private(%arg3: memref<3xf32, 5>, %arg4: memref<4xf32, 5>) {
+ workgroup(%arg1: memref<1xf32, #gpu.address_space<workgroup>>, %arg2: memref<2xf32, #gpu.address_space<workgroup>>)
+ private(%arg3: memref<3xf32, #gpu.address_space<private>>, %arg4: memref<4xf32, #gpu.address_space<private>>) {
// Workgroup buffers.
// NVVM: llvm.mlir.addressof @[[$buffer1]]
@@ -218,10 +218,10 @@ gpu.module @kernel {
// ROCDL: llvm.alloca %[[c4]] x f32 : (i64) -> !llvm.ptr<f32, 5>
%c0 = arith.constant 0 : index
- memref.store %arg0, %arg1[%c0] : memref<1xf32, 3>
- memref.store %arg0, %arg2[%c0] : memref<2xf32, 3>
- memref.store %arg0, %arg3[%c0] : memref<3xf32, 5>
- memref.store %arg0, %arg4[%c0] : memref<4xf32, 5>
+ memref.store %arg0, %arg1[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
+ memref.store %arg0, %arg2[%c0] : memref<2xf32, #gpu.address_space<workgroup>>
+ memref.store %arg0, %arg3[%c0] : memref<3xf32, #gpu.address_space<private>>
+ memref.store %arg0, %arg4[%c0] : memref<4xf32, #gpu.address_space<private>>
"terminator"() : () -> ()
}
}
diff --git a/mlir/test/Dialect/GPU/all-reduce-max.mlir b/mlir/test/Dialect/GPU/all-reduce-max.mlir
index d39b961c0085..a71544ba0e98 100644
--- a/mlir/test/Dialect/GPU/all-reduce-max.mlir
+++ b/mlir/test/Dialect/GPU/all-reduce-max.mlir
@@ -5,7 +5,7 @@
gpu.module @kernels {
// CHECK-LABEL: gpu.func @kernel(
- // CHECK-SAME: [[VAL_0:%.*]]: f32) workgroup([[VAL_1:%.*]] : memref<32xf32, 3>) kernel {
+ // CHECK-SAME: [[VAL_0:%.*]]: f32) workgroup([[VAL_1:%.*]] : memref<32xf32, #gpu.address_space<workgroup>>) kernel {
gpu.func @kernel(%arg0 : f32) kernel {
// CHECK-DAG: [[VAL_2:%.*]] = arith.constant 31 : i32
// CHECK-DAG: [[VAL_3:%.*]] = arith.constant 0 : i32
@@ -109,7 +109,7 @@ gpu.module @kernels {
// CHECK: ^bb19:
// CHECK: [[VAL_80:%.*]] = arith.divsi [[VAL_27]], [[VAL_5]] : i32
// CHECK: [[VAL_81:%.*]] = arith.index_cast [[VAL_80]] : i32 to index
- // CHECK: store [[VAL_79]], [[VAL_1]]{{\[}}[[VAL_81]]] : memref<32xf32, 3>
+ // CHECK: store [[VAL_79]], [[VAL_1]]{{\[}}[[VAL_81]]] : memref<32xf32, #gpu.address_space<workgroup>>
// CHECK: cf.br ^bb21
// CHECK: ^bb20:
// CHECK: cf.br ^bb21
@@ -121,7 +121,7 @@ gpu.module @kernels {
// CHECK: cf.cond_br [[VAL_84]], ^bb22, ^bb41
// CHECK: ^bb22:
// CHECK: [[VAL_85:%.*]] = arith.index_cast [[VAL_27]] : i32 to index
- // CHECK: [[VAL_86:%.*]] = memref.load [[VAL_1]]{{\[}}[[VAL_85]]] : memref<32xf32, 3>
+ // CHECK: [[VAL_86:%.*]] = memref.load [[VAL_1]]{{\[}}[[VAL_85]]] : memref<32xf32, #gpu.address_space<workgroup>>
// CHECK: [[VAL_87:%.*]] = arith.cmpi slt, [[VAL_83]], [[VAL_5]] : i32
// CHECK: cf.cond_br [[VAL_87]], ^bb23, ^bb39
// CHECK: ^bb23:
@@ -189,7 +189,7 @@ gpu.module @kernels {
// CHECK: [[VAL_132:%.*]] = arith.select [[VAL_131]], [[VAL_128]], [[VAL_129]] : f32
// CHECK: cf.br ^bb40([[VAL_132]] : f32)
// CHECK: ^bb40([[VAL_133:%.*]]: f32):
- // CHECK: store [[VAL_133]], [[VAL_1]]{{\[}}[[VAL_4]]] : memref<32xf32, 3>
+ // CHECK: store [[VAL_133]], [[VAL_1]]{{\[}}[[VAL_4]]] : memref<32xf32, #gpu.address_space<workgroup>>
// CHECK: cf.br ^bb42
// CHECK: ^bb41:
// CHECK: cf.br ^bb42
diff --git a/mlir/test/Dialect/GPU/all-reduce.mlir b/mlir/test/Dialect/GPU/all-reduce.mlir
index 67d83357e0ea..2a24e1de3bf3 100644
--- a/mlir/test/Dialect/GPU/all-reduce.mlir
+++ b/mlir/test/Dialect/GPU/all-reduce.mlir
@@ -5,7 +5,7 @@
gpu.module @kernels {
// CHECK-LABEL: gpu.func @kernel(
- // CHECK-SAME: [[VAL_0:%.*]]: f32) workgroup([[VAL_1:%.*]] : memref<32xf32, 3>) kernel {
+ // CHECK-SAME: [[VAL_0:%.*]]: f32) workgroup([[VAL_1:%.*]] : memref<32xf32, #gpu.address_space<workgroup>>) kernel {
gpu.func @kernel(%arg0 : f32) kernel {
// CHECK-DAG: [[VAL_2:%.*]] = arith.constant 31 : i32
// CHECK-DAG: [[VAL_3:%.*]] = arith.constant 0 : i32
@@ -99,7 +99,7 @@ gpu.module @kernels {
// CHECK: ^bb19:
// CHECK: [[VAL_70:%.*]] = arith.divsi [[VAL_27]], [[VAL_5]] : i32
// CHECK: [[VAL_71:%.*]] = arith.index_cast [[VAL_70]] : i32 to index
- // CHECK: store [[VAL_69]], [[VAL_1]]{{\[}}[[VAL_71]]] : memref<32xf32, 3>
+ // CHECK: store [[VAL_69]], [[VAL_1]]{{\[}}[[VAL_71]]] : memref<32xf32, #gpu.address_space<workgroup>>
// CHECK: cf.br ^bb21
// CHECK: ^bb20:
// CHECK: cf.br ^bb21
@@ -111,7 +111,7 @@ gpu.module @kernels {
// CHECK: cf.cond_br [[VAL_74]], ^bb22, ^bb41
// CHECK: ^bb22:
// CHECK: [[VAL_75:%.*]] = arith.index_cast [[VAL_27]] : i32 to index
- // CHECK: [[VAL_76:%.*]] = memref.load [[VAL_1]]{{\[}}[[VAL_75]]] : memref<32xf32, 3>
+ // CHECK: [[VAL_76:%.*]] = memref.load [[VAL_1]]{{\[}}[[VAL_75]]] : memref<32xf32, #gpu.address_space<workgroup>>
// CHECK: [[VAL_77:%.*]] = arith.cmpi slt, [[VAL_73]], [[VAL_5]] : i32
// CHECK: cf.cond_br [[VAL_77]], ^bb23, ^bb39
// CHECK: ^bb23:
@@ -169,7 +169,7 @@ gpu.module @kernels {
// CHECK: [[VAL_112:%.*]] = arith.addf [[VAL_109]], [[VAL_110]] : f32
// CHECK: cf.br ^bb40([[VAL_112]] : f32)
// CHECK: ^bb40([[VAL_113:%.*]]: f32):
- // CHECK: store [[VAL_113]], [[VAL_1]]{{\[}}[[VAL_4]]] : memref<32xf32, 3>
+ // CHECK: store [[VAL_113]], [[VAL_1]]{{\[}}[[VAL_4]]] : memref<32xf32, #gpu.address_space<workgroup>>
// CHECK: cf.br ^bb42
// CHECK: ^bb41:
// CHECK: cf.br ^bb42
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 76a14d353bc4..a139f4c3d854 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -350,7 +350,7 @@ module {
module {
gpu.module @gpu_funcs {
- // expected-error @+1 {{expected memref type in attribution}}
+ // expected-error @below {{'gpu.func' op expected memref type in attribution}}
gpu.func @kernel() workgroup(%0: i32) {
gpu.return
}
@@ -361,8 +361,8 @@ module {
module {
gpu.module @gpu_funcs {
- // expected-error @+1 {{expected memory space 3 in attribution}}
- gpu.func @kernel() workgroup(%0: memref<4xf32>) {
+ // expected-error @below {{'gpu.func' op expected memory space workgroup in attribution}}
+ gpu.func @kernel() workgroup(%0: memref<4xf32, #gpu.address_space<private>>) {
gpu.return
}
}
@@ -372,19 +372,8 @@ module {
module {
gpu.module @gpu_funcs {
- // expected-error @+1 {{expected memory space 5 in attribution}}
- gpu.func @kernel() private(%0: memref<4xf32>) {
- gpu.return
- }
- }
-}
-
-// -----
-
-module {
- gpu.module @gpu_funcs {
- // expected-error @+1 {{expected memory space 5 in attribution}}
- gpu.func @kernel() private(%0: memref<4xf32>) {
+ // expected-error @below {{'gpu.func' op expected memory space private in attribution}}
+ gpu.func @kernel() private(%0: memref<4xf32, #gpu.address_space<workgroup>>) {
gpu.return
}
}
diff --git a/mlir/test/Dialect/GPU/lower-memory-space-attrs.mlir b/mlir/test/Dialect/GPU/lower-memory-space-attrs.mlir
new file mode 100644
index 000000000000..9b4f1dee597b
--- /dev/null
+++ b/mlir/test/Dialect/GPU/lower-memory-space-attrs.mlir
@@ -0,0 +1,55 @@
+// RUN: mlir-opt %s -split-input-file -gpu-lower-memory-space-attributes | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -gpu-lower-memory-space-attributes="private=0 global=0" | FileCheck %s --check-prefix=CUDA
+
+gpu.module @kernel {
+ gpu.func @private(%arg0: f32) private(%arg1: memref<4xf32, #gpu.address_space<private>>) {
+ %c0 = arith.constant 0 : index
+ memref.store %arg0, %arg1[%c0] : memref<4xf32, #gpu.address_space<private>>
+ gpu.return
+ }
+}
+
+// CHECK: gpu.func @private
+// CHECK-SAME: private(%{{.+}}: memref<4xf32, 5>)
+// CHECK: memref.store
+// CHECK-SAME: : memref<4xf32, 5>
+
+// CUDA: gpu.func @private
+// CUDA-SAME: private(%{{.+}}: memref<4xf32>)
+// CUDA: memref.store
+// CUDA-SAME: : memref<4xf32>
+
+// -----
+
+gpu.module @kernel {
+ gpu.func @workgroup(%arg0: f32) workgroup(%arg1: memref<4xf32, #gpu.address_space<workgroup>>) {
+ %c0 = arith.constant 0 : index
+ memref.store %arg0, %arg1[%c0] : memref<4xf32, #gpu.address_space<workgroup>>
+ gpu.return
+ }
+}
+
+// CHECK: gpu.func @workgroup
+// CHECK-SAME: workgroup(%{{.+}}: memref<4xf32, 3>)
+// CHECK: memref.store
+// CHECK-SAME: : memref<4xf32, 3>
+
+// -----
+
+gpu.module @kernel {
+ gpu.func @nested_memref(%arg0: memref<4xmemref<4xf32, #gpu.address_space<global>>, #gpu.address_space<global>>) {
+ %c0 = arith.constant 0 : index
+ memref.load %arg0[%c0] : memref<4xmemref<4xf32, #gpu.address_space<global>>, #gpu.address_space<global>>
+ gpu.return
+ }
+}
+
+// CHECK: gpu.func @nested_memref
+// CHECK-SAME: (%{{.+}}: memref<4xmemref<4xf32, 1>, 1>)
+// CHECK: memref.load
+// CHECK-SAME: : memref<4xmemref<4xf32, 1>, 1>
+
+// CUDA: gpu.func @nested_memref
+// CUDA-SAME: (%{{.+}}: memref<4xmemref<4xf32>>)
+// CUDA: memref.load
+// CUDA-SAME: : memref<4xmemref<4xf32>>
diff --git a/mlir/test/Dialect/GPU/promotion.mlir b/mlir/test/Dialect/GPU/promotion.mlir
index db33f5cf4b5b..b4668b567889 100644
--- a/mlir/test/Dialect/GPU/promotion.mlir
+++ b/mlir/test/Dialect/GPU/promotion.mlir
@@ -5,7 +5,7 @@ gpu.module @foo {
// Verify that the attribution was indeed introduced
// CHECK-LABEL: @memref3d
// CHECK-SAME: (%[[arg:.*]]: memref<5x4xf32>
- // CHECK-SAME: workgroup(%[[promoted:.*]] : memref<5x4xf32, 3>)
+ // CHECK-SAME: workgroup(%[[promoted:.*]] : memref<5x4xf32, #gpu.address_space<workgroup>>)
gpu.func @memref3d(%arg0: memref<5x4xf32> {gpu.test_promote_workgroup}) kernel {
// Verify that loop bounds are emitted, the order does not matter.
// CHECK-DAG: %[[c1:.*]] = arith.constant 1
@@ -30,7 +30,7 @@ gpu.module @foo {
// CHECK: store %[[v]], %[[promoted]][%[[i1]], %[[i2]]]
// Verify that the use has been rewritten.
- // CHECK: "use"(%[[promoted]]) : (memref<5x4xf32, 3>)
+ // CHECK: "use"(%[[promoted]]) : (memref<5x4xf32, #gpu.address_space<workgroup>>)
"use"(%arg0) : (memref<5x4xf32>) -> ()
@@ -55,7 +55,7 @@ gpu.module @foo {
// Verify that the attribution was indeed introduced
// CHECK-LABEL: @memref5d
// CHECK-SAME: (%[[arg:.*]]: memref<8x7x6x5x4xf32>
- // CHECK-SAME: workgroup(%[[promoted:.*]] : memref<8x7x6x5x4xf32, 3>)
+ // CHECK-SAME: workgroup(%[[promoted:.*]] : memref<8x7x6x5x4xf32, #gpu.address_space<workgroup>>)
gpu.func @memref5d(%arg0: memref<8x7x6x5x4xf32> {gpu.test_promote_workgroup}) kernel {
// Verify that loop bounds are emitted, the order does not matter.
// CHECK-DAG: %[[c0:.*]] = arith.constant 0
@@ -84,7 +84,7 @@ gpu.module @foo {
// CHECK: store %[[v]], %[[promoted]][%[[i0]], %[[i1]], %[[i2]], %[[i3]], %[[i4]]]
// Verify that the use has been rewritten.
- // CHECK: "use"(%[[promoted]]) : (memref<8x7x6x5x4xf32, 3>)
+ // CHECK: "use"(%[[promoted]]) : (memref<8x7x6x5x4xf32, #gpu.address_space<workgroup>>)
"use"(%arg0) : (memref<8x7x6x5x4xf32>) -> ()
// Verify that loop loops for the copy are emitted.
@@ -108,11 +108,11 @@ gpu.module @foo {
// Check that attribution insertion works fine.
// CHECK-LABEL: @insert
// CHECK-SAME: (%{{.*}}: memref<4xf32>
- // CHECK-SAME: workgroup(%{{.*}}: memref<1x1xf64, 3>
- // CHECK-SAME: %[[wg2:.*]] : memref<4xf32, 3>)
+ // CHECK-SAME: workgroup(%{{.*}}: memref<1x1xf64, #gpu.address_space<workgroup>>
+ // CHECK-SAME: %[[wg2:.*]] : memref<4xf32, #gpu.address_space<workgroup>>)
// CHECK-SAME: private(%{{.*}}: memref<1x1xi64, 5>)
gpu.func @insert(%arg0: memref<4xf32> {gpu.test_promote_workgroup})
- workgroup(%arg1: memref<1x1xf64, 3>)
+ workgroup(%arg1: memref<1x1xf64, #gpu.address_space<workgroup>>)
private(%arg2: memref<1x1xi64, 5>)
kernel {
// CHECK: "use"(%[[wg2]])
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index 1d1205aca96e..57f91b334359 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -2,7 +2,7 @@
func.func @ldmatrix_address_space_f16_x4(%arg0: memref<128x128xf16, 2>) -> vector<4x1xf16> {
%c0 = arith.constant 0 : index
- // expected-error @+1 {{expected nvgpu.ldmatrix srcMemref must have memory space 3}}
+ // expected-error @below {{expected nvgpu.ldmatrix srcMemref must have a memory space attribute of IntegerAttr(3) or gpu::AddressSpaceAttr(Workgroup)}}
%a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 2> -> vector<4x1xf16>
return %a : vector<4x1xf16>
}
@@ -126,7 +126,7 @@ func.func @m16n8k32_int32_datatype(%arg0: vector<4x4xi32>, %arg1: vector<2x4xi8>
// -----
func.func @async_cp_memory_space(%dst : memref<16xf32>, %src : memref<16xf32>, %i : index) -> () {
- // expected-error @+1 {{destination memref must have memory space 3}}
+ // expected-error @below {{destination memref must have a memory space attribute of IntegerAttr(3) or gpu::AddressSpaceAttr(Workgroup)}}
nvgpu.device_async_copy %src[%i], %dst[%i], 16 : memref<16xf32> to memref<16xf32>
return
}