[Mlir-commits] [mlir] c441070 - [mlir][spirv] Add conversion from GPU WMMA ops to SPIRV Cooperative matrix
Nirvedh Meshram
llvmlistbot at llvm.org
Sat Oct 22 18:32:01 PDT 2022
Author: Nirvedh Meshram
Date: 2022-10-22T18:29:40-07:00
New Revision: c44107066550bfbdd86af71d21008599e871744a
URL: https://github.com/llvm/llvm-project/commit/c44107066550bfbdd86af71d21008599e871744a
DIFF: https://github.com/llvm/llvm-project/commit/c44107066550bfbdd86af71d21008599e871744a.diff
LOG: [mlir][spirv] Add conversion from GPU WMMA ops to SPIRV Cooperative matrix
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D136521
Added:
mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
Modified:
mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/test/Dialect/GPU/invalid.mlir
Removed:
################################################################################
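For context, the new lowering is exercised by the -convert-gpu-to-spirv pass (see the RUN line of the test added below). A minimal C++ sketch of driving the same conversion programmatically; the pass factory and pipeline wiring here are illustrative assumptions rather than part of this commit:

    // Hypothetical driver (sketch only): run the GPU-to-SPIR-V conversion on a
    // module whose gpu.module kernels carry a spirv.target_env attribute.
    #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Pass/PassManager.h"

    mlir::LogicalResult lowerGpuToSpirv(mlir::ModuleOp module) {
      mlir::PassManager pm(module.getContext());
      pm.addPass(mlir::createConvertGPUToSPIRVPass()); // assumed factory name
      return pm.run(module);
    }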
diff --git a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
index 132866f756f9a..3c3281513f60d 100644
--- a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
+++ b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
@@ -13,16 +13,28 @@
#ifndef MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H
#define MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
class SPIRVTypeConverter;
+namespace gpu {
+class MMAMatrixType;
+} // namespace gpu
+
/// Appends to a pattern list additional patterns for translating GPU Ops to
/// SPIR-V ops. For a gpu.func to be converted, it should have a
/// spirv.entry_point_abi attribute.
void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
+
+/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV.
+void populateGpuWMMAToSPIRVConversionPatterns(SPIRVTypeConverter &typeConverter,
+ RewritePatternSet &patterns);
+
+spirv::CooperativeMatrixNVType convertMMAToSPIRVType(gpu::MMAMatrixType type);
} // namespace mlir
#endif // MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H
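A conversion driver is expected to register both the MMA matrix type conversion and the WMMA patterns declared above. The sketch below mirrors the wiring added to GPUToSPIRVPass.cpp further down; the local names targetAttr and context are assumed to be in scope:

    // Sketch: hook the MMA matrix type and the WMMA patterns into a SPIR-V
    // type converter and pattern set.
    SPIRVTypeConverter typeConverter(targetAttr);
    typeConverter.addConversion(
        [](gpu::MMAMatrixType type) -> Type { return convertMMAToSPIRVType(type); });

    RewritePatternSet patterns(context);
    populateGPUToSPIRVPatterns(typeConverter, patterns);
    populateGpuWMMAToSPIRVConversionPatterns(typeConverter, patterns);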
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index f5dedb04b0778..0cf2029a232df 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1246,19 +1246,35 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
}];
}
-def GPU_ElementwiseOpAdd : I32EnumAttrCase<"ADDF", 0, "addf">;
-def GPU_ElementwiseOpMul : I32EnumAttrCase<"MULF", 1, "mulf">;
-def GPU_ElementwiseOpMaxF : I32EnumAttrCase<"MAXF", 2, "maxf">;
-def GPU_ElementwiseOpMinF : I32EnumAttrCase<"MINF", 3, "minf">;
-def GPU_ElementwiseOpDivF : I32EnumAttrCase<"DIVF", 4, "divf">;
+def GPU_ElementwiseOpAddF : I32EnumAttrCase<"ADDF", 0, "addf">;
+def GPU_ElementwiseOpMulF : I32EnumAttrCase<"MULF", 1, "mulf">;
+def GPU_ElementwiseOpSUBF : I32EnumAttrCase<"SUBF", 2, "subf">;
+def GPU_ElementwiseOpMaxF : I32EnumAttrCase<"MAXF", 3, "maxf">;
+def GPU_ElementwiseOpMinF : I32EnumAttrCase<"MINF", 4, "minf">;
+def GPU_ElementwiseOpDivF : I32EnumAttrCase<"DIVF", 5, "divf">;
+def GPU_ElementwiseOpAddI : I32EnumAttrCase<"ADDI", 6, "addi">;
+def GPU_ElementwiseOpMulI : I32EnumAttrCase<"MULI", 7, "muli">;
+def GPU_ElementwiseOpSUBI : I32EnumAttrCase<"SUBI", 8, "subi">;
+def GPU_ElementwiseOpDivS : I32EnumAttrCase<"DIVS", 9, "divs">;
+def GPU_ElementwiseOpDivU : I32EnumAttrCase<"DIVU", 10, "divu">;
+def GPU_ElementwiseOpNEGF : I32EnumAttrCase<"NEGATEF", 11, "negatef">;
+def GPU_ElementwiseOpNEGS : I32EnumAttrCase<"NEGATES", 12, "negates">;
def MMAElementWise : I32EnumAttr<"MMAElementwiseOp",
"elementwise operation to apply to mma matrix", [
- GPU_ElementwiseOpAdd,
- GPU_ElementwiseOpMul,
+ GPU_ElementwiseOpAddF,
+ GPU_ElementwiseOpMulF,
+ GPU_ElementwiseOpSUBF,
GPU_ElementwiseOpMaxF,
GPU_ElementwiseOpMinF,
- GPU_ElementwiseOpDivF
+ GPU_ElementwiseOpDivF,
+ GPU_ElementwiseOpAddI,
+ GPU_ElementwiseOpMulI,
+ GPU_ElementwiseOpSUBI,
+ GPU_ElementwiseOpDivS,
+ GPU_ElementwiseOpDivU,
+ GPU_ElementwiseOpNEGF,
+ GPU_ElementwiseOpNEGS
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::gpu";
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 50b5c4932f0ca..b124625f70543 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -42,7 +42,18 @@ class SPIRV_ArithmeticUnaryOp<string mnemonic, Type type,
// Operand type same as result type.
SPIRV_UnaryOp<mnemonic, type, type,
!listconcat(traits,
- [Pure, SameOperandsAndResultType])>;
+ [Pure, SameOperandsAndResultType])> {
+ // In addition to normal types, these arithmetic instructions can also
+ // support cooperative matrix types.
+ let arguments = (ins
+ SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand
+ );
+
+ let results = (outs
+ SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$result
+ );
+ let assemblyFormat = "operands attr-dict `:` type($result)";
+ }
// -----
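With the widened operand and result types above, the generated builders of these unary arithmetic ops accept cooperative matrix values directly. A small sketch of the kind of use made by the elementwise lowering added below in WmmaOpsToSPIRV.cpp; `rewriter`, `loc`, and the 16x16 f16 operand `mat` are assumptions for illustration:

    // Sketch: negate a whole cooperative matrix with a single SPIR-V op.
    auto coopType = spirv::CooperativeMatrixNVType::get(
        rewriter.getF16Type(), spirv::Scope::Subgroup, /*rows=*/16, /*columns=*/16);
    Value negated = rewriter.create<spirv::FNegateOp>(loc, coopType, mat);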
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 8ecbdf240c773..963a718dc5f36 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -311,8 +311,9 @@ static Value createScalarOp(OpBuilder &builder, Location loc,
case gpu::MMAElementwiseOp::MINF:
return createMinMaxF(builder, loc, operands[0], operands[1],
/*isMin=*/true);
+ default:
+ llvm_unreachable("unknown op");
}
- llvm_unreachable("unknown op");
}
/// Convert GPU MMA elementwise ops to extract + op + insert.
diff --git a/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt
index 7ee339ab7f147..3b3156d090dfa 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_conversion_library(MLIRGPUToSPIRV
GPUToSPIRV.cpp
GPUToSPIRVPass.cpp
+ WmmaOpsToSPIRV.cpp
DEPENDS
MLIRConversionPassIncGen
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index c425346e45766..f1c4e32da827f 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -86,9 +86,12 @@ void GPUToSPIRVPass::runOnOperation() {
SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
+ typeConverter.addConversion([&](gpu::MMAMatrixType type) -> Type {
+ return convertMMAToSPIRVType(type);
+ });
RewritePatternSet patterns(context);
populateGPUToSPIRVPatterns(typeConverter, patterns);
-
+ populateGpuWMMAToSPIRVConversionPatterns(typeConverter, patterns);
// TODO: Change SPIR-V conversion to be progressive and remove the following
// patterns.
mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
new file mode 100644
index 0000000000000..c890d41d87b66
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -0,0 +1,203 @@
+//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
+// SPIRV Dialect ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
+#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+
+// See SPV_NV_cooperative_matrix for supported elementwise ops.
+static void createElementWiseOp(ConversionPatternRewriter &builder,
+ gpu::SubgroupMmaElementwiseOp op,
+ spirv::CooperativeMatrixNVType coopType,
+ ValueRange operands) {
+ switch (op.getOpType()) {
+ case gpu::MMAElementwiseOp::ADDF:
+ builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
+ return;
+ case gpu::MMAElementwiseOp::ADDI:
+ builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands);
+ return;
+ case gpu::MMAElementwiseOp::SUBF:
+ builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands);
+ return;
+ case gpu::MMAElementwiseOp::SUBI:
+ builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
+ return;
+ case gpu::MMAElementwiseOp::DIVF:
+ builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
+ return;
+ case gpu::MMAElementwiseOp::DIVS:
+ builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands);
+ return;
+ case gpu::MMAElementwiseOp::DIVU:
+ builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands);
+ return;
+ case gpu::MMAElementwiseOp::NEGATEF:
+ builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands);
+ return;
+ case gpu::MMAElementwiseOp::NEGATES:
+ builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
+ return;
+ default:
+ llvm_unreachable("unknown op");
+ }
+}
+
+namespace {
+
+/// This class implements the conversion of GPU MMA loadOp to
+/// CooperativeMatrixLoad op in the SPIRV dialect.
+struct WmmaLoadOpToSPIRVLowering
+ : public OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
+ OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = subgroupMmaLoadMatrixOp->getLoc();
+ gpu::MMAMatrixType retType =
+ subgroupMmaLoadMatrixOp.getRes().getType().cast<gpu::MMAMatrixType>();
+ auto memrefType =
+ subgroupMmaLoadMatrixOp.getSrcMemref().getType().cast<MemRefType>();
+ Value bufferPtr = spirv::getElementPtr(
+ *getTypeConverter<SPIRVTypeConverter>(), memrefType,
+ adaptor.getSrcMemref(), adaptor.getIndices(), loc, rewriter);
+ auto coopType = convertMMAToSPIRVType(retType);
+ int64_t stride = subgroupMmaLoadMatrixOp.getLeadDimension().getSExtValue();
+ auto i32Type = rewriter.getI32Type();
+ auto strideValue = rewriter.create<spirv::ConstantOp>(
+ loc, i32Type, IntegerAttr::get(i32Type, stride));
+ auto coloumnMajor = rewriter.create<spirv::ConstantOp>(
+ loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
+ rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixLoadOp>(
+ subgroupMmaLoadMatrixOp, coopType, bufferPtr, strideValue, coloumnMajor,
+ spirv::MemoryAccessAttr());
+ return success();
+ }
+};
+
+/// This class implements the conversion of GPU MMA StoreOp to
+/// CooperativeMatrixStore op in the SPIRV dialect.
+struct WmmaStoreOpToSPIRVLowering
+ : public OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
+ OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = subgroupMmaStoreMatrixOp->getLoc();
+ auto memrefType =
+ subgroupMmaStoreMatrixOp.getDstMemref().getType().cast<MemRefType>();
+ Value bufferPtr = spirv::getElementPtr(
+ *getTypeConverter<SPIRVTypeConverter>(), memrefType,
+ adaptor.getDstMemref(), adaptor.getIndices(), loc, rewriter);
+ int64_t stride = subgroupMmaStoreMatrixOp.getLeadDimension().getSExtValue();
+ auto i32Type = rewriter.getI32Type();
+ auto strideValue = rewriter.create<spirv::ConstantOp>(
+ loc, i32Type, IntegerAttr::get(i32Type, stride));
+ auto coloumnMajor = rewriter.create<spirv::ConstantOp>(
+ loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
+ rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixStoreOp>(
+ subgroupMmaStoreMatrixOp, bufferPtr, adaptor.getSrc(), strideValue,
+ coloumnMajor, spirv::MemoryAccessAttr());
+ return success();
+ }
+};
+
+/// This class implements the conversion of GPU MMA Compute to
+/// CooperativeMatrixMulAdd op in the SPIRV dialect.
+struct WmmaMmaOpToSPIRVLowering
+ : public OpConversionPattern<gpu::SubgroupMmaComputeOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
+ OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixMulAddOp>(
+ subgroupMmaComputeOp, adaptor.getOpC().getType(), adaptor.getOpA(),
+ adaptor.getOpB(), adaptor.getOpC());
+ return success();
+ }
+};
+
+/// Convert GPU MMA ConstantMatrixOp to constant SPIR-V cooperative matrix ops.
+struct WmmaConstantOpToSPIRVLowering
+ : public OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantMatrixOp,
+ OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Value cst = adaptor.getOperands()[0];
+ auto coopType = convertMMAToSPIRVType(
+ subgroupMmaConstantMatrixOp.getType().cast<gpu::MMAMatrixType>());
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
+ subgroupMmaConstantMatrixOp, coopType, cst);
+ return success();
+ }
+};
+
+/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops.
+struct WmmaElementwiseOpToSPIRVLowering
+ : public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
+ OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // All operands should be of cooperative matrix types.
+ for (Value operand : adaptor.getOperands()) {
+ if (!operand.getType().isa<spirv::CooperativeMatrixNVType>())
+ return failure();
+ }
+ auto coopType = convertMMAToSPIRVType(
+ subgroupMmaElementwiseOp.getType().cast<gpu::MMAMatrixType>());
+ createElementWiseOp(rewriter, subgroupMmaElementwiseOp, coopType,
+ adaptor.getOperands());
+ return success();
+ }
+};
+
+} // namespace
+
+/// Return the SPIR-V cooperative matrix type corresponding to the
+/// MMAMatrixType `type`.
+mlir::spirv::CooperativeMatrixNVType
+mlir::convertMMAToSPIRVType(gpu::MMAMatrixType type) {
+ ArrayRef<int64_t> retTypeShape = type.getShape();
+ Type elementType = type.getElementType();
+ return spirv::CooperativeMatrixNVType::get(
+ elementType, spirv::Scope::Subgroup, retTypeShape[0], retTypeShape[1]);
+}
+
+void mlir::populateGpuWMMAToSPIRVConversionPatterns(
+ SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
+ patterns.add<WmmaLoadOpToSPIRVLowering, WmmaMmaOpToSPIRVLowering,
+ WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
+ WmmaElementwiseOpToSPIRVLowering>(converter,
+ patterns.getContext());
+}
\ No newline at end of file
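As a quick illustration of the type mapping implemented by convertMMAToSPIRVType, a 16x16 f16 "COp" matrix maps to a subgroup-scoped cooperative matrix; the MMAMatrixType construction below is only illustrative and assumes the usual get() signature:

    // !gpu.mma_matrix<16x16xf16, "COp">  ->  !spirv.coopmatrix<16x16xf16, Subgroup>
    Builder b(context);
    auto srcType = gpu::MMAMatrixType::get({16, 16}, b.getF16Type(), "COp");
    auto coopType = convertMMAToSPIRVType(srcType);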
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 4f5cfb65c924a..3eee6081e7eef 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1192,18 +1192,11 @@ LogicalResult SubgroupMmaLoadMatrixOp::verify() {
auto resMatrixType = resType.cast<gpu::MMAMatrixType>();
auto operand = resMatrixType.getOperand();
auto srcMemrefType = srcType.cast<MemRefType>();
- auto srcMemSpace = srcMemrefType.getMemorySpaceAsInt();
if (!isLastMemrefDimUnitStride(srcMemrefType))
return emitError(
"expected source memref most minor dim must have unit stride");
- if (srcMemSpace != kGenericMemorySpace && srcMemSpace != kSharedMemorySpace &&
- srcMemSpace != kGlobalMemorySpace)
- return emitError(
- "source memorySpace kGenericMemorySpace, kSharedMemorySpace or "
- "kGlobalMemorySpace only allowed");
-
if (!operand.equals("AOp") && !operand.equals("BOp") &&
!operand.equals("COp"))
return emitError("only AOp, BOp and COp can be loaded");
@@ -1220,17 +1213,11 @@ LogicalResult SubgroupMmaStoreMatrixOp::verify() {
auto dstType = getDstMemref().getType();
auto srcMatrixType = srcType.cast<gpu::MMAMatrixType>();
auto dstMemrefType = dstType.cast<MemRefType>();
- auto dstMemSpace = dstMemrefType.getMemorySpaceAsInt();
if (!isLastMemrefDimUnitStride(dstMemrefType))
return emitError(
"expected destination memref most minor dim must have unit stride");
- if (dstMemSpace != kGenericMemorySpace && dstMemSpace != kSharedMemorySpace &&
- dstMemSpace != kGlobalMemorySpace)
- return emitError("destination memorySpace of kGenericMemorySpace, "
- "kGlobalMemorySpace or kSharedMemorySpace only allowed");
-
if (!srcMatrixType.getOperand().equals("COp"))
return emitError(
"expected the operand matrix being stored to have 'COp' operand type");
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
new file mode 100644
index 0000000000000..9f2a27cf0e864
--- /dev/null
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
@@ -0,0 +1,110 @@
+// RUN: mlir-opt -allow-unregistered-dialect -convert-gpu-to-spirv -split-input-file -verify-diagnostics %s | FileCheck %s
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
+ gpu.module @kernels {
+ // CHECK: spirv.module @{{.*}} Logical GLSL450 {
+ // CHECK-LABEL: spirv.func @gpu_wmma_load_op
+ // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
+ // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]> : vector<3xi32>>
+ gpu.func @gpu_wmma_load_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
+ %i = arith.constant 16 : index
+ %j = arith.constant 16 : index
+ // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<f32, StorageBuffer> as !spirv.coopmatrix<16x16xf16, Subgroup>
+ %0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
+ // CHECK: spirv.Return
+ gpu.return
+ }
+ }
+}
+
+// -----
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
+ gpu.module @kernels {
+ // CHECK: spirv.module @{{.*}} Logical GLSL450 {
+ // CHECK-LABEL: spirv.func @gpu_wmma_store_op
+ // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
+ // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
+ // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]> : vector<3xi32>>
+ gpu.func @gpu_wmma_store_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>, %arg1 : !gpu.mma_matrix<16x16xf16, "COp">) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
+ %i = arith.constant 16 : index
+ %j = arith.constant 16 : index
+ // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup>
+ gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class<StorageBuffer>>
+ // CHECK: spirv.Return
+ gpu.return
+ }
+ }
+}
+
+// -----
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
+ gpu.module @kernels {
+ // CHECK: spirv.module @{{.*}} Logical GLSL450 {
+ // CHECK-LABEL: spirv.func @gpu_wmma_mma_op
+ // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
+ // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
+ // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 2)>})
+ // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]> : vector<3xi32>>
+ gpu.func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
+ // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup>
+ %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+ // CHECK: spirv.Return
+ gpu.return
+ }
+ }
+}
+
+// -----
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
+ gpu.module @kernels {
+ // CHECK: spirv.module @{{.*}} Logical GLSL450 {
+ // CHECK-LABEL: spirv.func @gpu_wmma_constant_op
+ gpu.func @gpu_wmma_constant_op() kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
+ // CHECK: {{%.*}} = spirv.Constant
+ %cst = arith.constant 1.0 : f16
+ // CHECK: {{%.*}} = spirv.CompositeConstruct {{%.*}} : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup>
+ %C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">
+ // CHECK: spirv.Return
+ gpu.return
+ }
+ }
+}
+
+// -----
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
+ gpu.module @kernels {
+ // CHECK: spirv.module @{{.*}} Logical GLSL450 {
+ // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op
+ // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
+ // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
+ gpu.func @gpu_wmma_elementwise_op(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
+ // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>
+ %C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+ // CHECK: {{%.*}} = spirv.FNegate {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>
+ %D = gpu.subgroup_mma_elementwise negatef %C : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+ // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>
+ %E = gpu.subgroup_mma_elementwise divf %D, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+ // CHECK: spirv.Return
+ gpu.return
+ }
+ }
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index b029d2fa7c9a4..6390de3e3a10d 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -515,16 +515,6 @@ func.func @mmaLoadOp_identity_layout(){
// -----
-func.func @mmaLoadOp_invalid_mem_space(){
- %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 5>
- %i = arith.constant 16 : index
- // expected-error @+1 {{source memorySpace kGenericMemorySpace, kSharedMemorySpace or kGlobalMemorySpace only allowed}}
- %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 5> -> !gpu.mma_matrix<16x16xf16, "AOp">
- return
-}
-
-// -----
-
#layout_map_col_major = affine_map<(i, j) -> (j, i)>
func.func @wmmaStoreOp_invalid_map(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
@@ -538,17 +528,6 @@ func.func @wmmaStoreOp_invalid_map(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) ->
// -----
-func.func @wmmaStoreOp_invalid_mem_space(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
- %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 5>
- %i = arith.constant 16 : index
- %j = arith.constant 16 : index
- // expected-error @+1 {{destination memorySpace of kGenericMemorySpace, kGlobalMemorySpace or kSharedMemorySpace only allowed}}
- gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 5>
- return
-}
-
-// -----
-
func.func @wmmaStoreOp_invalid_store_operand(%arg0 : !gpu.mma_matrix<16x16xf16, "AOp">) -> () {
%sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3>
%i = arith.constant 16 : index