[Mlir-commits] [mlir] [mlir][spirv][gpu] Default to KHR coop matrix. Clean up type conversion. (PR #67485)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 26 13:34:05 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-spirv
<details>
<summary>Changes</summary>
- Now that the KHR cooperative matrix implementation is robust, make it the default for the GPU to SPIR-V conversion pass (the `use-coop-matrix-nv` option now defaults to `false`).
- Use a single populate function to register the MMA to cooperative matrix type conversions, instead of exposing separate KHR and NV conversion helpers. This makes the API surface area smaller.
---
Full diff: https://github.com/llvm/llvm-project/pull/67485.diff
4 Files Affected:
- (modified) mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h (+5-13)
- (modified) mlir/include/mlir/Conversion/Passes.td (+1-1)
- (modified) mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp (+2-8)
- (modified) mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp (+39-28)
``````````diff
diff --git a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
index c258513ed4878ea..cd650345f1daa2d 100644
--- a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
+++ b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
@@ -20,10 +20,6 @@
namespace mlir {
class SPIRVTypeConverter;
-namespace gpu {
-class MMAMatrixType;
-} // namespace gpu
-
/// Appends to a pattern list additional patterns for translating GPU Ops to
/// SPIR-V ops. For a gpu.func to be converted, it should have a
/// spirv.entry_point_abi attribute.
@@ -40,15 +36,11 @@ void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
void populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
-/// Returns a KHR cooperative matrix type corresponding to the MMAMatrixType
-/// `type`.
-spirv::CooperativeMatrixType
-convertMMAToSPIRVCoopMatrixType(gpu::MMAMatrixType type);
-
-/// Returns an NV cooperative matrix type corresponding to the MMAMatrixType
-/// `type`.
-spirv::CooperativeMatrixNVType
-convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type);
+/// Adds `MMAMatrixType` conversions to SPIR-V cooperative matrix type
+/// conversion to the type converter. Defaults to KHR cooperative matrix types.
+/// When `useNVTypes` is `true`, uses the NV cooperative matrix types.
+void populateMMAToSPIRVCoopMatrixTypeConversion(
+ SPIRVTypeConverter &typeConverter, bool useNVTypes = false);
} // namespace mlir
#endif // MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 11008baa0160efe..5e0f976b18f7da5 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -569,7 +569,7 @@ def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> {
"bool", /*default=*/"false",
"Use 64-bit integers to convert index types">,
Option<"useCoopMatrixNV", "use-coop-matrix-nv",
- "bool", /*default=*/"true",
+ "bool", /*default=*/"false",
"Use the NV cooperative matrix extension insted of the KHR extension"
" to lower GPU WMMA ops">,
];
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index 5b05c45bf602509..272e3de8723aeb6 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -86,14 +86,8 @@ void GPUToSPIRVPass::runOnOperation() {
SPIRVConversionOptions options;
options.use64bitIndex = this->use64bitIndex;
SPIRVTypeConverter typeConverter(targetAttr, options);
-
- typeConverter.addConversion([useNV = this->useCoopMatrixNV.getValue()](
- gpu::MMAMatrixType type) -> Type {
- if (useNV)
- return convertMMAToSPIRVCoopMatrixNVType(type);
-
- return convertMMAToSPIRVCoopMatrixType(type);
- });
+ populateMMAToSPIRVCoopMatrixTypeConversion(typeConverter,
+ this->useCoopMatrixNV);
RewritePatternSet patterns(context);
populateGPUToSPIRVPatterns(typeConverter, patterns);
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index eb7fcb63d920d8f..4a4281aaaf0dbc4 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -311,14 +311,21 @@ struct WmmaLoadOpToSPIRVLowering final
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = subgroupMmaLoadMatrixOp->getLoc();
+ auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+
gpu::MMAMatrixType retType =
cast<gpu::MMAMatrixType>(subgroupMmaLoadMatrixOp.getRes().getType());
auto memrefType =
cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType());
- Value bufferPtr = spirv::getElementPtr(
- *getTypeConverter<const SPIRVTypeConverter>(), memrefType,
- adaptor.getSrcMemref(), adaptor.getIndices(), loc, rewriter);
- auto coopType = convertMMAToSPIRVCoopMatrixNVType(retType);
+ Value bufferPtr =
+ spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(),
+ adaptor.getIndices(), loc, rewriter);
+ auto coopType =
+ typeConverter.convertType<spirv::CooperativeMatrixNVType>(retType);
+ if (!coopType)
+ return rewriter.notifyMatchFailure(subgroupMmaLoadMatrixOp,
+ "type conversion failed");
+
int64_t stride = subgroupMmaLoadMatrixOp.getLeadDimension().getSExtValue();
auto i32Type = rewriter.getI32Type();
auto strideValue = rewriter.create<spirv::ConstantOp>(
@@ -385,30 +392,6 @@ struct WmmaMmaOpToSPIRVLowering final
} // namespace nv
} // namespace mlir
-mlir::spirv::CooperativeMatrixNVType
-mlir::convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type) {
- ArrayRef<int64_t> retTypeShape = type.getShape();
- Type elementType = type.getElementType();
- return spirv::CooperativeMatrixNVType::get(
- elementType, spirv::Scope::Subgroup, retTypeShape[0], retTypeShape[1]);
-}
-
-mlir::spirv::CooperativeMatrixType
-mlir::convertMMAToSPIRVCoopMatrixType(gpu::MMAMatrixType type) {
- ArrayRef<int64_t> retTypeShape = type.getShape();
- Type elementType = type.getElementType();
-
- auto use =
- llvm::StringSwitch<spirv::CooperativeMatrixUseKHR>(type.getOperand())
- .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
- .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
- .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
-
- return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0],
- retTypeShape[1],
- spirv::Scope::Subgroup, use);
-}
-
void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
using namespace mlir;
@@ -432,3 +415,31 @@ void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
/*benefit=*/2);
}
+
+void mlir::populateMMAToSPIRVCoopMatrixTypeConversion(
+ mlir::SPIRVTypeConverter &typeConverter, bool useNVTypes) {
+ if (useNVTypes) {
+ typeConverter.addConversion([](gpu::MMAMatrixType type) {
+ ArrayRef<int64_t> retTypeShape = type.getShape();
+ Type elementType = type.getElementType();
+ return spirv::CooperativeMatrixNVType::get(
+ elementType, spirv::Scope::Subgroup, retTypeShape[0],
+ retTypeShape[1]);
+ });
+ return;
+ }
+
+ typeConverter.addConversion([](gpu::MMAMatrixType type) {
+ ArrayRef<int64_t> retTypeShape = type.getShape();
+ Type elementType = type.getElementType();
+ auto use =
+ llvm::StringSwitch<spirv::CooperativeMatrixUseKHR>(type.getOperand())
+ .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
+ .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
+ .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
+
+ return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0],
+ retTypeShape[1],
+ spirv::Scope::Subgroup, use);
+ });
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/67485
More information about the Mlir-commits
mailing list