[Mlir-commits] [mlir] [mlir][spirv][gpu] Clean up wmma to coop matrix NV conversion. NFC. (PR #66278)
Jakub Kuderski
llvmlistbot at llvm.org
Wed Sep 13 12:27:16 PDT 2023
https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/66278:
This is a cleanup in preparation for adding a second conversion path using the KHR cooperative matrix extension.
Make the existing lowering explicit about emitting ops from the NV coop matrix extension. Clean up surrounding code.
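For context, here is a minimal sketch of how a downstream conversion setup would pick up the renamed entry points after this change. Only the two `...CoopMatrixNV...` functions and the type-converter hookup come from the patch below; the helper name `addWmmaCoopMatrixNVLowering` and the exact include set are illustrative, not part of this PR.

```cpp
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Assumes `typeConverter` and `patterns` were created the same way
// GPUToSPIRVPass.cpp does (from the target env attribute and options).
static void addWmmaCoopMatrixNVLowering(SPIRVTypeConverter &typeConverter,
                                        RewritePatternSet &patterns) {
  // Map gpu.mma_matrix types to the NV cooperative matrix type.
  typeConverter.addConversion([](gpu::MMAMatrixType type) -> Type {
    return convertMMAToSPIRVCoopMatrixNVType(type);
  });
  // Register the core GPU-to-SPIR-V patterns plus the renamed WMMA
  // lowering that explicitly targets the NV coop matrix extension.
  populateGPUToSPIRVPatterns(typeConverter, patterns);
  populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter,
                                                       patterns);
}
```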
From 01aa362e74aa4a3885c4a8644952147c50616706 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 13 Sep 2023 15:24:34 -0400
Subject: [PATCH] [mlir][spirv][gpu] Clean up wmma to coop matrix NV
conversion. NFC.
This is a cleanup in preparation for adding a second conversion path
using the KHR cooperative matrix extension.
Make the existing lowering explicit about emitting ops from the NV coop
matrix extension. Clean up surrounding code.
---
.../mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h | 14 ++--
.../Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp | 11 +--
.../Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp | 75 ++++++++++---------
... => wmma-ops-to-spirv-nv-coop-matrix.mlir} | 2 +-
4 files changed, 55 insertions(+), 47 deletions(-)
rename mlir/test/Conversion/GPUToSPIRV/{wmma-ops-to-spirv.mlir => wmma-ops-to-spirv-nv-coop-matrix.mlir} (98%)
diff --git a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
index 3c3281513f60d89..6c4643da1884900 100644
--- a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
+++ b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
@@ -30,11 +30,15 @@ class MMAMatrixType;
void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
-/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV.
-void populateGpuWMMAToSPIRVConversionPatterns(SPIRVTypeConverter &typeConverter,
- RewritePatternSet &patterns);
-
-spirv::CooperativeMatrixNVType convertMMAToSPIRVType(gpu::MMAMatrixType type);
+/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
+/// using the NV Cooperative Matrix extension.
+void populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
+ SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
+
+/// Returns an NV cooperative matrix type corresponding to the MMAMatrixType
+/// `type`.
+spirv::CooperativeMatrixNVType
+convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type);
} // namespace mlir
#endif // MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index f37c70a771f5916..d0ce58597f980d4 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -39,8 +39,7 @@ namespace {
/// replace it).
///
/// 2) Lower the body of the spirv::ModuleOp.
-class GPUToSPIRVPass : public impl::ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
-public:
+struct GPUToSPIRVPass final : impl::ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
explicit GPUToSPIRVPass(bool mapMemorySpace)
: mapMemorySpace(mapMemorySpace) {}
void runOnOperation() override;
@@ -48,7 +47,6 @@ class GPUToSPIRVPass : public impl::ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
private:
bool mapMemorySpace;
};
-} // namespace
void GPUToSPIRVPass::runOnOperation() {
MLIRContext *context = &getContext();
@@ -89,11 +87,12 @@ void GPUToSPIRVPass::runOnOperation() {
options.use64bitIndex = this->use64bitIndex;
SPIRVTypeConverter typeConverter(targetAttr, options);
typeConverter.addConversion([&](gpu::MMAMatrixType type) -> Type {
- return convertMMAToSPIRVType(type);
+ return convertMMAToSPIRVCoopMatrixNVType(type);
});
RewritePatternSet patterns(context);
populateGPUToSPIRVPatterns(typeConverter, patterns);
- populateGpuWMMAToSPIRVConversionPatterns(typeConverter, patterns);
+ populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter,
+ patterns);
// TODO: Change SPIR-V conversion to be progressive and remove the following
// patterns.
mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
@@ -105,6 +104,8 @@ void GPUToSPIRVPass::runOnOperation() {
}
}
+} // namespace
+
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertGPUToSPIRVPass(bool mapMemorySpace) {
return std::make_unique<GPUToSPIRVPass>(mapMemorySpace);
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index 3851fb728b6654b..bf3fff027fe384a 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -1,4 +1,4 @@
-//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering------===//
+//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
//
// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
-// SPIRV Dialect ops.
+// SPIRV Cooperative Matrix ops.
//
//===----------------------------------------------------------------------===//
@@ -22,7 +22,8 @@
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/TypeUtilities.h"
-using namespace mlir;
+namespace mlir::nv {
+namespace {
/// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
/// when the elementwise op directly supports with cooperative matrix type.
@@ -70,12 +71,10 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
return false;
}
-namespace {
-
-/// This class implements the conversion of GPU MMA loadOp to
-/// CooperativeMatrixLoad op in the SPIRV dialect.
-struct WmmaLoadOpToSPIRVLowering
- : public OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
+/// Converts the GPU MMA loadOp to NVCooperativeMatrixLoad op in the SPIRV
+/// dialect.
+struct WmmaLoadOpToSPIRVLowering final
+ : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
@@ -90,7 +89,7 @@ struct WmmaLoadOpToSPIRVLowering
Value bufferPtr = spirv::getElementPtr(
*getTypeConverter<const SPIRVTypeConverter>(), memrefType,
adaptor.getSrcMemref(), adaptor.getIndices(), loc, rewriter);
- auto coopType = convertMMAToSPIRVType(retType);
+ auto coopType = convertMMAToSPIRVCoopMatrixNVType(retType);
int64_t stride = subgroupMmaLoadMatrixOp.getLeadDimension().getSExtValue();
auto i32Type = rewriter.getI32Type();
auto strideValue = rewriter.create<spirv::ConstantOp>(
@@ -105,10 +104,10 @@ struct WmmaLoadOpToSPIRVLowering
}
};
-/// This class implements the conversion of GPU MMA StoreOp to
-/// CooperativeMatrixStore op in the SPIRV dialect.
-struct WmmaStoreOpToSPIRVLowering
- : public OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
+/// Converts the GPU MMA StoreOp to NVCooperativeMatrixStore op in the SPIRV
+/// dialect.
+struct WmmaStoreOpToSPIRVLowering final
+ : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
@@ -136,10 +135,10 @@ struct WmmaStoreOpToSPIRVLowering
}
};
-/// This class implements the conversion of GPU MMA Compute to
-/// CooperativeMatrixMulAdd op in the SPIRV dialect.
-struct WmmaMmaOpToSPIRVLowering
- : public OpConversionPattern<gpu::SubgroupMmaComputeOp> {
+/// Converts GPU MMA Compute to
+/// NVCooperativeMatrixMulAdd op in the SPIRV dialect.
+struct WmmaMmaOpToSPIRVLowering final
+ : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
@@ -153,9 +152,10 @@ struct WmmaMmaOpToSPIRVLowering
}
};
-/// Convert GPU MMA ConstantMatrixOp to constant SPIR-V cooperative matrix ops.
-struct WmmaConstantOpToSPIRVLowering
- : public OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
+/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V NV cooperative matrix
+/// ops.
+struct WmmaConstantOpToSPIRVLowering final
+ : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
@@ -163,7 +163,7 @@ struct WmmaConstantOpToSPIRVLowering
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value cst = adaptor.getOperands()[0];
- auto coopType = convertMMAToSPIRVType(
+ auto coopType = convertMMAToSPIRVCoopMatrixNVType(
cast<gpu::MMAMatrixType>(subgroupMmaConstantMatrixOp.getType()));
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
subgroupMmaConstantMatrixOp, coopType, cst);
@@ -173,8 +173,8 @@ struct WmmaConstantOpToSPIRVLowering
/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// the default case.
-struct WmmaElementwiseOpToSPIRVDefaultLowering
- : public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+struct WmmaElementwiseOpToSPIRVDefaultLowering final
+ : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
@@ -186,7 +186,7 @@ struct WmmaElementwiseOpToSPIRVDefaultLowering
if (!isa<spirv::CooperativeMatrixNVType>(operand.getType()))
return failure();
}
- auto coopType = convertMMAToSPIRVType(
+ auto coopType = convertMMAToSPIRVCoopMatrixNVType(
cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
return success(createElementwiseOp(rewriter, elementwiseOp, coopType,
adaptor.getOperands()));
@@ -195,8 +195,8 @@ struct WmmaElementwiseOpToSPIRVDefaultLowering
/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// matrix times scalar case.
-struct WmmaElementwiseOpToSPIRVScalarMulLowering
- : public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+struct WmmaElementwiseOpToSPIRVScalarMulLowering final
+ : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
@@ -238,7 +238,7 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering
assert(cc.getConstituents().size() == 1);
scalar = cc.getConstituents().front();
- auto coopType = convertMMAToSPIRVType(
+ auto coopType = convertMMAToSPIRVCoopMatrixNVType(
cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
elementwiseOp, coopType, ValueRange{matrix, scalar});
@@ -247,23 +247,26 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering
};
} // namespace
+} // namespace mlir::nv
-/// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
mlir::spirv::CooperativeMatrixNVType
-mlir::convertMMAToSPIRVType(gpu::MMAMatrixType type) {
+mlir::convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type) {
ArrayRef<int64_t> retTypeShape = type.getShape();
Type elementType = type.getElementType();
return spirv::CooperativeMatrixNVType::get(
elementType, spirv::Scope::Subgroup, retTypeShape[0], retTypeShape[1]);
}
-void mlir::populateGpuWMMAToSPIRVConversionPatterns(
+void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
+ using namespace mlir;
MLIRContext *context = patterns.getContext();
- patterns.add<WmmaLoadOpToSPIRVLowering, WmmaMmaOpToSPIRVLowering,
- WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
- WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
+ patterns
+ .add<nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering,
+ nv::WmmaStoreOpToSPIRVLowering, nv::WmmaConstantOpToSPIRVLowering,
+ nv::WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
// Give the following patterns higher benefit to prevail over the default one.
- patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
- /*benefit=*/2);
+ patterns.add<nv::WmmaElementwiseOpToSPIRVScalarMulLowering>(converter,
+ context,
+ /*benefit=*/2);
}
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir
similarity index 98%
rename from mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
rename to mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir
index a53eca65fc98699..5811c791f308d1e 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -convert-gpu-to-spirv -split-input-file -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --convert-gpu-to-spirv --split-input-file --verify-diagnostics %s | FileCheck %s
module attributes {
gpu.container_module,