[Mlir-commits] [mlir] 2c7827d - [mlir][spirv] Lower GPU subgroup MMA elementwise mulf to spirv.MatrixTimesScalar
Lei Zhang
llvmlistbot at llvm.org
Mon Dec 5 14:30:59 PST 2022
Author: Lei Zhang
Date: 2022-12-05T22:30:50Z
New Revision: 2c7827da4f5bf758a1659a0c4a2d0e7944827c42
URL: https://github.com/llvm/llvm-project/commit/2c7827da4f5bf758a1659a0c4a2d0e7944827c42
DIFF: https://github.com/llvm/llvm-project/commit/2c7827da4f5bf758a1659a0c4a2d0e7944827c42.diff
LOG: [mlir][spirv] Lower GPU subgroup MMA elementwise mulf to spirv.MatrixTimesScalar
Along the way, make the default pattern fail instead of crashing
when an elementwise op is not supported yet.
Reviewed By: kuhar
Differential Revision: https://reviews.llvm.org/D139280
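
For context on the "fail instead of crashing" part of this change: the helper now reports whether it handled the op, and the pattern wraps that bool with success(), so the conversion driver can back off cleanly when an elementwise op kind has no cooperative-matrix lowering. The following is a minimal sketch of that idiom only; the enum and function names are hypothetical stand-ins, not code from this commit, and it assumes the standard mlir/Support/LogicalResult.h header.

#include "mlir/Support/LogicalResult.h"

using mlir::LogicalResult;
using mlir::success;

// Hypothetical stand-ins for the GPU elementwise op kinds handled above.
enum class ElementwiseKind { AddF, SubF, MulF, Unsupported };

// Mirrors createElementwiseOp's contract: return true when a replacement op
// was (notionally) emitted, false when the kind has no cooperative-matrix
// lowering.
static bool tryCreateElementwiseOp(ElementwiseKind kind) {
  switch (kind) {
  case ElementwiseKind::AddF:
  case ElementwiseKind::SubF:
  case ElementwiseKind::MulF:
    return true; // the real helper calls rewriter.replaceOpWithNewOp<...> here
  default:
    return false; // previously this path hit llvm_unreachable
  }
}

// In the pattern's matchAndRewrite, success(bool) turns that bool into a
// LogicalResult: success(false) is failure(), so the conversion driver can
// try other patterns or report a proper legalization error.
static LogicalResult lowerOrFail(ElementwiseKind kind) {
  return success(tryCreateElementwiseOp(kind));
}
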
Added:
Modified:
mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index 43d7f6237671..81e7d5c29437 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -24,42 +24,47 @@
using namespace mlir;
-// See SPV_NV_cooperative_matrix for supported element wise ops.
-static void createElementWiseOp(ConversionPatternRewriter &builder,
+/// Creates a SPIR-V op to replace the given GPU subgroup MMA elementwise op
+/// when the elementwise op can be applied directly to the cooperative matrix
+/// type. Returns false if the op kind is not supported.
+///
+/// See SPV_NV_cooperative_matrix for supported elementwise ops.
+static bool createElementwiseOp(ConversionPatternRewriter &builder,
gpu::SubgroupMmaElementwiseOp op,
spirv::CooperativeMatrixNVType coopType,
ValueRange operands) {
switch (op.getOpType()) {
case gpu::MMAElementwiseOp::ADDF:
builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
- return;
+ return true;
case gpu::MMAElementwiseOp::ADDI:
builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands);
- return;
+ return true;
case gpu::MMAElementwiseOp::SUBF:
builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands);
- return;
+ return true;
case gpu::MMAElementwiseOp::SUBI:
builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
- return;
+ return true;
case gpu::MMAElementwiseOp::DIVF:
builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
- return;
+ return true;
case gpu::MMAElementwiseOp::DIVS:
builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands);
- return;
+ return true;
case gpu::MMAElementwiseOp::DIVU:
builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands);
- return;
+ return true;
case gpu::MMAElementwiseOp::NEGATEF:
builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands);
- return;
+ return true;
case gpu::MMAElementwiseOp::NEGATES:
builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
- return;
+ return true;
default:
- llvm_unreachable("unknown op");
+ break;
}
+ return false;
}
namespace {
@@ -163,13 +168,14 @@ struct WmmaConstantOpToSPIRVLowering
}
};
-/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops.
-struct WmmaElementwiseOpToSPIRVLowering
+/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
+/// the default case.
+struct WmmaElementwiseOpToSPIRVDefaultLowering
: public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
+ matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// All operands should be of cooperative matrix types.
@@ -178,9 +184,58 @@ struct WmmaElementwiseOpToSPIRVLowering
return failure();
}
auto coopType = convertMMAToSPIRVType(
- subgroupMmaElementwiseOp.getType().cast<gpu::MMAMatrixType>());
- createElementWiseOp(rewriter, subgroupMmaElementwiseOp, coopType,
- adaptor.getOperands());
+ elementwiseOp.getType().cast<gpu::MMAMatrixType>());
+ return success(createElementwiseOp(rewriter, elementwiseOp, coopType,
+ adaptor.getOperands()));
+ }
+};
+
+/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
+/// the matrix-times-scalar case.
+struct WmmaElementwiseOpToSPIRVScalarMulLowering
+ : public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp,
+ OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (adaptor.getOperands().size() != 2)
+ return failure();
+ // All operands should be of cooperative matrix types.
+ for (Value operand : adaptor.getOperands()) {
+ if (!operand.getType().isa<spirv::CooperativeMatrixNVType>())
+ return failure();
+ }
+
+ // Use the original operands to check whether one of the operands is a splat
+ // scalar value.
+ Value lhs = elementwiseOp.getOperands().front();
+ Value rhs = elementwiseOp.getOperands().back();
+ Value splat = nullptr;
+ Value matrix = nullptr;
+ if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
+ splat = adaptor.getOperands().front();
+ matrix = adaptor.getOperands().back();
+ } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
+ matrix = adaptor.getOperands().front();
+ splat = adaptor.getOperands().back();
+ }
+ if (!splat || !matrix)
+ return failure();
+
+ // Constant MMA matrix ops are converted to spirv.CompositeConstruct ops.
+ Value scalar = nullptr;
+ auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
+ if (!cc)
+ return failure();
+ assert(cc.getConstituents().size() == 1);
+ scalar = cc.getConstituents().front();
+
+ auto coopType = convertMMAToSPIRVType(
+ elementwiseOp.getType().cast<gpu::MMAMatrixType>());
+ rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
+ elementwiseOp, coopType, ValueRange{matrix, scalar});
return success();
}
};
@@ -198,8 +253,11 @@ mlir::convertMMAToSPIRVType(gpu::MMAMatrixType type) {
void mlir::populateGpuWMMAToSPIRVConversionPatterns(
SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
+ MLIRContext *context = patterns.getContext();
patterns.add<WmmaLoadOpToSPIRVLowering, WmmaMmaOpToSPIRVLowering,
WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
- WmmaElementwiseOpToSPIRVLowering>(converter,
- patterns.getContext());
+ WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
+ // Give the following pattern a higher benefit so it prevails over the default one.
+ patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
+ /*benefit=*/2);
}
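
To see where the new pattern's benefit of 2 matters, here is a sketch of how a caller could wire these patterns into a dialect conversion. The header paths and utility calls (lookupTargetEnvOrDefault, SPIRVConversionTarget::get, applyPartialConversion) are assumptions based on the surrounding tree at the time of this commit, not part of the change itself.

#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"

#include <memory>
#include <utility>

using namespace mlir;

static LogicalResult lowerWmmaToSPIRV(ModuleOp module) {
  // Derive the SPIR-V target from the module's spirv.target_env attribute.
  spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(module);
  std::unique_ptr<ConversionTarget> target =
      SPIRVConversionTarget::get(targetAttr);

  SPIRVTypeConverter typeConverter(targetAttr);
  RewritePatternSet patterns(module.getContext());
  // Registers both elementwise lowerings; the scalar-multiplication pattern
  // has benefit 2, so the driver attempts it first and falls back to the
  // default elementwise lowering when no constant-matrix operand is found.
  populateGpuWMMAToSPIRVConversionPatterns(typeConverter, patterns);

  return applyPartialConversion(module, *target, std::move(patterns));
}

The upstream GPU-to-SPIR-V conversion pass performs similar wiring together with the other GPU patterns; the sketch only isolates the WMMA registration to show how the higher-benefit pattern shadows the default one.
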
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
index c4dc7458bc31..b544ae5a1238 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
@@ -4,10 +4,8 @@ module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
gpu.module @kernels {
- // CHECK: spirv.module @{{.*}} Logical GLSL450 {
// CHECK-LABEL: spirv.func @gpu_wmma_load_op
- // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
- // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>
+ // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
gpu.func @gpu_wmma_load_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
%i = arith.constant 16 : index
@@ -27,7 +25,6 @@ module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
gpu.module @kernels {
- // CHECK: spirv.module @{{.*}} Logical GLSL450 {
// CHECK-LABEL: spirv.func @gpu_wmma_load_op_transpose
// CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
// CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>
@@ -50,11 +47,9 @@ module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
gpu.module @kernels {
- // CHECK: spirv.module @{{.*}} Logical GLSL450 {
// CHECK-LABEL: spirv.func @gpu_wmma_store_op
- // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
- // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
- // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>
+ // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
+ // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
gpu.func @gpu_wmma_store_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>, %arg1 : !gpu.mma_matrix<16x16xf16, "COp">) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
%i = arith.constant 16 : index
@@ -74,7 +69,6 @@ module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
gpu.module @kernels {
- // CHECK: spirv.module @{{.*}} Logical GLSL450 {
// CHECK-LABEL: spirv.func @gpu_wmma_store_op_transpose
// CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
// CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
@@ -98,12 +92,10 @@ module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
gpu.module @kernels {
- // CHECK: spirv.module @{{.*}} Logical GLSL450 {
// CHECK-LABEL: spirv.func @gpu_wmma_mma_op
- // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
- // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
- // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 2)>})
- // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>
+ // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
+ // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
+ // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
gpu.func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
// CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup>
@@ -120,7 +112,6 @@ module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
gpu.module @kernels {
- // CHECK: spirv.module @{{.*}} Logical GLSL450 {
// CHECK-LABEL: spirv.func @gpu_wmma_constant_op
gpu.func @gpu_wmma_constant_op() kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
@@ -140,11 +131,10 @@ module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
gpu.module @kernels {
- // CHECK: spirv.module @{{.*}} Logical GLSL450 {
- // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op
- // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
- // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
- gpu.func @gpu_wmma_elementwise_op(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) kernel
+ // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
+ // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
+ // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
+ gpu.func @gpu_wmma_elementwise_op_default(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
// CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>
%C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
@@ -157,3 +147,24 @@ module attributes {
}
}
}
+
+// -----
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
+ gpu.module @kernels {
+ // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_times_scalar
+ // CHECK-SAME: %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup>
+ // CHECK-SAME: %[[S:.+]]: f16
+ gpu.func @gpu_wmma_elementwise_op_matrix_times_scalar(%A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+ %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
+ // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
+ %C = gpu.subgroup_mma_elementwise mulf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+ // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
+ %D = gpu.subgroup_mma_elementwise mulf %B, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+ gpu.return
+ }
+ }
+}