[Mlir-commits] [mlir] [mlir][spirv][gpu] Convert remaining wmma ops to KHR coop matrix (PR #66455)
Jakub Kuderski
llvmlistbot at llvm.org
Thu Sep 14 19:18:57 PDT 2023
https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/66455:
>From 992b57544e7261bb18741e13fbe4b9634e488299 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Thu, 14 Sep 2023 21:06:43 -0400
Subject: [PATCH] [mlir][spirv][gpu] Convert remaining wmma ops to KHR coop
matrix
The remaining ops (constant matrix and elementwise) do not produce
extension-specific ops, so they are handled via common patterns shared by
both the KHR and the NV coop matrix extensions.
Also improve match failure reporting and error handling in type
conversion.
---
.../Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp | 231 ++++++++++--------
.../wmma-ops-to-spirv-khr-coop-matrix.mlir | 96 +++++++-
2 files changed, 224 insertions(+), 103 deletions(-)
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index d73cd5686d66e92..eb7fcb63d920d8f 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -24,11 +24,17 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/ValueRange.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include <cassert>
namespace mlir {
+//===----------------------------------------------------------------------===//
+// Patterns and helpers used by both the KHR and the NV lowering paths.
+//===----------------------------------------------------------------------===//
+
/// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
/// when the elementwise op directly supports cooperative matrix types.
/// Returns false if it cannot.
@@ -77,6 +83,119 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
return false;
}
+/// Returns true if all `operands` have the same type and that type is either
+/// `spirv::CooperativeMatrixType` or `spirv::CooperativeMatrixNVType`.
+/// `operands` must not be empty (asserted).
+///
+/// Declared `static` because it is a helper local to this file, matching
+/// `createElementwiseOp` above.
+static bool allOperandsHaveSameCoopMatrixType(ValueRange operands) {
+  assert(!operands.empty());
+  if (!llvm::all_equal(
+          llvm::map_range(operands, [](Value v) { return v.getType(); })))
+    return false;
+
+  // All operand types are equal here, so inspecting the first one suffices.
+  return isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>(
+      operands.front().getType());
+}
+
+namespace {
+/// Lowers `gpu.subgroup_mma_constant_matrix` to a `spirv.CompositeConstruct`
+/// whose result is the converted (KHR or NV) cooperative matrix type.
+struct WmmaConstantOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    ValueRange convertedOperands = adaptor.getOperands();
+    assert(convertedOperands.size() == 1);
+
+    // The result type is whatever the type converter maps the MMA matrix type
+    // to; give up with a diagnostic when there is no such mapping.
+    Type resultType = getTypeConverter()->convertType(op.getType());
+    if (!resultType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    // The single converted operand becomes the only constituent.
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
+        op, resultType, convertedOperands.front());
+    return success();
+  }
+};
+
+/// Default lowering of `gpu.subgroup_mma_elementwise`: delegates to
+/// `createElementwiseOp` once the converted operands are known to share a
+/// single cooperative matrix type.
+struct WmmaElementwiseOpToSPIRVDefaultLowering final
+    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    ValueRange convertedOperands = adaptor.getOperands();
+
+    // Every converted operand must be a cooperative matrix of the same type.
+    if (!allOperandsHaveSameCoopMatrixType(convertedOperands))
+      return rewriter.notifyMatchFailure(op,
+                                         "not all operands are coop matrices");
+
+    Type resultType = getTypeConverter()->convertType(op.getType());
+    if (!resultType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    // `createElementwiseOp` returns false when it cannot create a replacement.
+    const bool replaced =
+        createElementwiseOp(rewriter, op, resultType, convertedOperands);
+    return success(replaced);
+  }
+};
+
+/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
+/// the matrix times scalar case: a two-operand `mulf` where one operand is
+/// defined by a `gpu.subgroup_mma_constant_matrix` op is rewritten to
+/// `spirv.MatrixTimesScalar` on the other operand and the splatted scalar.
+///
+/// Every rejection path reports its reason through `notifyMatchFailure`.
+struct WmmaElementwiseOpToSPIRVScalarMulLowering final
+    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (adaptor.getOperands().size() != 2)
+      return rewriter.notifyMatchFailure(op, "expected two operands");
+
+    // All operands should be of cooperative matrix types.
+    if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
+      return rewriter.notifyMatchFailure(op,
+                                         "not all operands are coop matrices");
+    }
+
+    if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
+      return rewriter.notifyMatchFailure(op, "not a mulf elementwise op");
+
+    // Use the original operands to check whether one of the operands is a splat
+    // scalar value.
+    Value lhs = op.getOperands().front();
+    Value rhs = op.getOperands().back();
+    Value splat = nullptr;
+    Value matrix = nullptr;
+    if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
+      splat = adaptor.getOperands().front();
+      matrix = adaptor.getOperands().back();
+    } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
+      matrix = adaptor.getOperands().front();
+      splat = adaptor.getOperands().back();
+    }
+    if (!splat || !matrix)
+      return rewriter.notifyMatchFailure(op, "no splat operand");
+
+    // Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops.
+    auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
+    if (!cc) {
+      return rewriter.notifyMatchFailure(op,
+                                         "splat is not a composite construct");
+    }
+
+    // The composite construct built for a constant matrix has exactly one
+    // constituent: the scalar being splatted.
+    assert(cc.getConstituents().size() == 1);
+    Value scalar = cc.getConstituents().front();
+
+    auto coopType = getTypeConverter()->convertType(op.getType());
+    if (!coopType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+    rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
+        op, coopType, ValueRange{matrix, scalar});
+    return success();
+  }
+};
+} // namespace
+
//===----------------------------------------------------------------------===//
// SPV_KHR_cooperative_matrix
//===----------------------------------------------------------------------===//
@@ -262,100 +381,6 @@ struct WmmaMmaOpToSPIRVLowering final
}
};
-/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V NV cooperative matrix
-/// ops.
-struct WmmaConstantOpToSPIRVLowering final
- : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantMatrixOp,
- OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Value cst = adaptor.getOperands()[0];
- auto coopType = convertMMAToSPIRVCoopMatrixNVType(
- cast<gpu::MMAMatrixType>(subgroupMmaConstantMatrixOp.getType()));
- rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
- subgroupMmaConstantMatrixOp, coopType, cst);
- return success();
- }
-};
-
-/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
-/// the default case.
-struct WmmaElementwiseOpToSPIRVDefaultLowering final
- : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp,
- OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // All operands should be of cooperative matrix types.
- for (Value operand : adaptor.getOperands()) {
- if (!isa<spirv::CooperativeMatrixNVType>(operand.getType()))
- return failure();
- }
- auto coopType = convertMMAToSPIRVCoopMatrixNVType(
- cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
- return success(createElementwiseOp(rewriter, elementwiseOp, coopType,
- adaptor.getOperands()));
- }
-};
-
-/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
-/// matrix times scalar case.
-struct WmmaElementwiseOpToSPIRVScalarMulLowering final
- : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp,
- OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- if (adaptor.getOperands().size() != 2)
- return failure();
- // All operands should be of cooperative matrix types.
- for (Value operand : adaptor.getOperands()) {
- if (!isa<spirv::CooperativeMatrixNVType>(operand.getType()))
- return failure();
- }
-
- if (elementwiseOp.getOpType() != gpu::MMAElementwiseOp::MULF)
- return failure();
-
- // Use the original operands to check whether one of the operands is a splat
- // scalar value.
- Value lhs = elementwiseOp.getOperands().front();
- Value rhs = elementwiseOp.getOperands().back();
- Value splat = nullptr;
- Value matrix = nullptr;
- if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
- splat = adaptor.getOperands().front();
- matrix = adaptor.getOperands().back();
- } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
- matrix = adaptor.getOperands().front();
- splat = adaptor.getOperands().back();
- }
- if (!splat || !matrix)
- return failure();
-
- // Constant MMA matrix ops are converted to spirv.CompositeConstruct ops.
- Value scalar = nullptr;
- auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
- if (!cc)
- return failure();
- assert(cc.getConstituents().size() == 1);
- scalar = cc.getConstituents().front();
-
- auto coopType = convertMMAToSPIRVCoopMatrixNVType(
- cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
- rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
- elementwiseOp, coopType, ValueRange{matrix, scalar});
- return success();
- }
-};
-
} // namespace
} // namespace nv
} // namespace mlir
@@ -389,19 +414,21 @@ void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
using namespace mlir;
MLIRContext *context = patterns.getContext();
patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
- khr::WmmaStoreOpToSPIRVLowering>(converter, context);
+ khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
+ WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
+ // Give the scalar-mul pattern a higher benefit so that it prevails over the
+ // default elementwise pattern.
+ patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
+ /*benefit=*/2);
}
void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
using namespace mlir;
MLIRContext *context = patterns.getContext();
- patterns
- .add<nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering,
- nv::WmmaStoreOpToSPIRVLowering, nv::WmmaConstantOpToSPIRVLowering,
- nv::WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
+ patterns.add<nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering,
+ nv::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
+ WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
// Give the following patterns higher benefit to prevail over the default one.
- patterns.add<nv::WmmaElementwiseOpToSPIRVScalarMulLowering>(converter,
- context,
- /*benefit=*/2);
+ patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
+ /*benefit=*/2);
}
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
index 0818791b98471da..f129cc8ce84ec39 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
@@ -69,12 +69,106 @@ module attributes {
-> !gpu.mma_matrix<16x16xf16, "COp">
%i = arith.constant 0 : index
- // CHECK: spirv.KHR.CooperativeMatrixStore {{%.+}}, %[[MAD]], %{{.+}}, <RowMajor>
+ // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[MAD]], %{{.+}}, <RowMajor>
gpu.subgroup_mma_store_matrix %D, %ptr[%i,%i] {leadDimension = 32 : index} :
!gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
// CHECK: spirv.Return
gpu.return
}
+ // Constant matrix lowering: the scalar constant is expected to be wrapped in
+ // a `spirv.CompositeConstruct` yielding the cooperative matrix, which is then
+ // passed to the cooperative matrix store.
+ // CHECK-LABEL: spirv.func @gpu_wmma_constant_op
+ gpu.func @gpu_wmma_constant_op(%ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+ // CHECK: %[[CST1F:.+]] = spirv.Constant 1.000000e+00 : f16
+ %cst = arith.constant 1.0 : f16
+ // CHECK: %[[MAT:.+]] = spirv.CompositeConstruct %[[CST1F]] :
+ // CHECK-SAME: (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+ %C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">
+
+ %i = arith.constant 0 : index
+ // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[MAT]], %{{.+}}, <RowMajor>
+ gpu.subgroup_mma_store_matrix %C, %ptr[%i,%i] {leadDimension = 32 : index} :
+ !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
+ // CHECK: spirv.Return
+ gpu.return
+ }
+
+ // Default elementwise lowering: addf, negatef, divf and extf on cooperative
+ // matrices are expected to become spirv.FAdd, spirv.FNegate, spirv.FDiv and
+ // spirv.FConvert on the converted cooperative matrix types.
+ // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
+ // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+ // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+ gpu.func @gpu_wmma_elementwise_op_default(%A: !gpu.mma_matrix<16x16xf16, "COp">,
+ %B: !gpu.mma_matrix<16x16xf16, "COp">,
+ %ptr: memref<16x16xf32, #spirv.storage_class<StorageBuffer>>) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+ // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+ %C = gpu.subgroup_mma_elementwise addf %A, %B :
+ (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+ // CHECK: {{%.*}} = spirv.FNegate {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+ %D = gpu.subgroup_mma_elementwise negatef %C :
+ (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+ // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+ %E = gpu.subgroup_mma_elementwise divf %D, %A :
+ (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+ // CHECK: {{%.*}} = spirv.FConvert {{%.*}} :
+ // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> to !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+ %F = gpu.subgroup_mma_elementwise extf %E :
+ (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
+
+ %i = arith.constant 0 : index
+ // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %{{.+}}, %{{.+}}, <RowMajor>
+ gpu.subgroup_mma_store_matrix %F, %ptr[%i,%i] {leadDimension = 32 : index} :
+ !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32, #spirv.storage_class<StorageBuffer>>
+ // CHECK: spirv.Return
+ gpu.return
+ }
+
+ // Matrix-times-scalar lowering: a mulf whose operand on either side is a
+ // splat constant matrix is expected to lower to spirv.MatrixTimesScalar that
+ // takes the scalar function argument directly.
+ // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_times_scalar
+ // CHECK-SAME: %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+ // CHECK-SAME: %[[S:.+]]: f16
+ gpu.func @gpu_wmma_elementwise_op_matrix_times_scalar(
+ %A: !gpu.mma_matrix<16x16xf16, "COp">, %scalar: f16,
+ %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+ %i = arith.constant 0 : index
+
+ %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
+ // Splat on the right-hand side.
+ // CHECK: %[[C:.+]] = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16
+ // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[C]], %{{.+}}, <RowMajor>
+ %C = gpu.subgroup_mma_elementwise mulf %A, %B :
+ (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+ gpu.subgroup_mma_store_matrix %C, %ptr[%i,%i] {leadDimension = 32 : index} :
+ !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
+
+ // Splat on the left-hand side; the matrix is still expected as the first
+ // operand of the resulting op.
+ // CHECK: %[[D:.+]] = spirv.MatrixTimesScalar %[[C]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16
+ // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[D]], %{{.+}}, <RowMajor>
+ %D = gpu.subgroup_mma_elementwise mulf %B, %C :
+ (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+ gpu.subgroup_mma_store_matrix %D, %ptr[%i,%i] {leadDimension = 32 : index} :
+ !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
+ // CHECK: spirv.Return
+ gpu.return
+ }
+
+ // Matrix-plus-scalar case: for addf the splat is expected to be materialized
+ // with spirv.CompositeConstruct and then added with a regular spirv.FAdd on
+ // cooperative matrices.
+ // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_plus_scalar
+ // CHECK-SAME: %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+ // CHECK-SAME: %[[S:.+]]: f16
+ gpu.func @gpu_wmma_elementwise_op_matrix_plus_scalar(
+ %A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16,
+ %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+ %i = arith.constant 0 : index
+
+ // CHECK: %[[SM:.+]] = spirv.CompositeConstruct %[[S]] : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+ %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
+ // CHECK: %[[C:.+]] = spirv.FAdd %[[A]], %[[SM]] : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+ %C = gpu.subgroup_mma_elementwise addf %A, %B :
+ (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+
+ // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[C]], %{{.+}}, <RowMajor>
+ gpu.subgroup_mma_store_matrix %C, %ptr[%i,%i] {leadDimension = 32 : index} :
+ !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
+ // CHECK: spirv.Return
+ gpu.return
+ }
}
}
More information about the Mlir-commits
mailing list