[Mlir-commits] [mlir] 875eb52 - [MLIR][GPU][NVVM] Add warp synchronous matrix-multiply accumulate ops
Uday Bondhugula
llvmlistbot at llvm.org
Wed May 5 23:37:45 PDT 2021
Author: Navdeep Kumar
Date: 2021-05-06T12:06:25+05:30
New Revision: 875eb523c13249114507cb8facd797773e278d9e
URL: https://github.com/llvm/llvm-project/commit/875eb523c13249114507cb8facd797773e278d9e
DIFF: https://github.com/llvm/llvm-project/commit/875eb523c13249114507cb8facd797773e278d9e.diff
LOG: [MLIR][GPU][NVVM] Add warp synchronous matrix-multiply accumulate ops
Add warp synchronous matrix-multiply accumulate ops to the GPU and NVVM
dialects. The following three ops are added to the GPU dialect:
1. subgroup_mma_load_matrix
2. subgroup_mma_store_matrix
3. subgroup_mma_compute
The following three ops are added to the NVVM dialect:
1. wmma.m16n16k16.load.[a,b,c].[f16,f32].row.stride
2. wmma.m16n16k16.store.d.[f16,f32].row.stride
3. wmma.m16n16k16.mma.row.row.[f16,f32].[f16,f32]
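
As an illustration, here is a minimal sketch of how the three GPU dialect ops
are meant to compose (assembled from the op definitions and tests in the diff
below; the SSA values %srcA, %srcB, %srcC, %dst and %c0 are assumed memrefs and
an index constant defined elsewhere):

  %A = gpu.subgroup_mma_load_matrix %srcA[%c0, %c0] {leadDimension = 16 : index}
         : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
  %B = gpu.subgroup_mma_load_matrix %srcB[%c0, %c0] {leadDimension = 16 : index}
         : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
  %C = gpu.subgroup_mma_load_matrix %srcC[%c0, %c0] {leadDimension = 16 : index}
         : memref<16x16xf32> -> !gpu.mma_matrix<16x16xf32, "COp">
  %D = gpu.subgroup_mma_compute %A, %B, %C
         : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">,
           !gpu.mma_matrix<16x16xf32, "COp"> -> !gpu.mma_matrix<16x16xf32, "DOp">
  gpu.subgroup_mma_store_matrix %D, %dst[%c0, %c0] {leadDimension = 16 : index}
         : !gpu.mma_matrix<16x16xf32, "DOp">, memref<16x16xf32>
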
Reviewed By: bondhugula, ftynse, ThomasRaoux
Differential Revision: https://reviews.llvm.org/D95330
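
At the NVVM level, the new ops wrap the corresponding intrinsics directly. A
short sketch of the load and store forms, mirroring the nvvmir.mlir test added
in this patch (%ptr is an assumed !llvm.ptr<i32, 3>, %ldm an assumed i32, and
%f0-%f3 assumed vector<2xf16> fragments):

  %frag = nvvm.wmma.m16n16k16.load.a.f16.row.stride %ptr, %ldm
            : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>,
              vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
              vector<2xf16>, vector<2xf16>)>
  nvvm.wmma.m16n16k16.store.d.f16.row.stride %ptr, %f0, %f1, %f2, %f3, %ldm
            : !llvm.ptr<i32, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
              vector<2xf16>, i32
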
Added:
Modified:
mlir/include/mlir/Dialect/GPU/GPUBase.td
mlir/include/mlir/Dialect/GPU/GPUDialect.h
mlir/include/mlir/Dialect/GPU/GPUOps.td
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Dialect/GPU/invalid.mlir
mlir/test/Dialect/GPU/ops.mlir
mlir/test/Dialect/LLVMIR/invalid.mlir
mlir/test/Target/LLVMIR/nvvmir.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/GPUBase.td b/mlir/include/mlir/Dialect/GPU/GPUBase.td
index 63cc7df9ed0d5..3945c3b8c47e3 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUBase.td
@@ -57,6 +57,17 @@ def GPU_AsyncToken : DialectType<
GPU_Dialect, CPred<"$_self.isa<::mlir::gpu::AsyncTokenType>()">, "async token type">,
BuildableType<"mlir::gpu::AsyncTokenType::get($_builder.getContext())">;
+// Predicate to check if a type is gpu::MMAMatrixType.
+def IsMMAMatrixTypePred : CPred<"$_self.isa<::mlir::gpu::MMAMatrixType>()">;
+
+def GPU_MMAMatrix : DialectType<
+ GPU_Dialect, IsMMAMatrixTypePred, "MMAMatrix type">;
+
+class MMAMatrixOf<list<Type> allowedTypes> :
+ ContainerType<AnyTypeOf<allowedTypes>, IsMMAMatrixTypePred,
+ "$_self.cast<::mlir::gpu::MMAMatrixType>().getElementType()",
+ "gpu.mma_matrix", "::mlir::gpu::MMAMatrixType">;
+
def GPU_AsyncOpInterface : OpInterface<"AsyncOpInterface"> {
let description = [{
Interface for GPU operations that execute asynchronously on the device.
@@ -102,4 +113,18 @@ def GPU_AsyncOpInterface : OpInterface<"AsyncOpInterface"> {
];
}
+// Cases of the String enum Attribute for SubgroupMmaOpLayout, representing
+// the layouts of the operands supported by the ops that use this attribute.
+def RowMajor: StrEnumAttrCase<"RowMajor", 0>;
+def ColMajor: StrEnumAttrCase<"ColMajor", 1>;
+
+// Specifies a String enum Attribute for warp-wide matrix operations,
+// representing the layout of the respective operands. The layout later
+// governs the lowering to the appropriate intrinsics.
+def SubgroupMmaOpLayout: StrEnumAttr<"Layout", "Specifies whether op is row/col major",
+ [RowMajor, ColMajor]> {
+ let stringToSymbolFnName = "LayoutStrToEnum";
+ let symbolToStringFnName = "EnumToLayoutStr";
+}
+
#endif // GPU_BASE
diff --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
index 26ab171727146..7832d48945656 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
@@ -44,6 +44,122 @@ class AsyncTokenType
using Base::Base;
};
+/// MMAMatrixType storage and uniquing. The type is uniqued based on its shape,
+/// element type, and operand.
+struct MMAMatrixStorageType : public TypeStorage {
+ MMAMatrixStorageType(unsigned numDims, const int64_t *dimShapes,
+ Type elementType, StringRef operand)
+ : dimShapes(dimShapes), numDims(numDims), elementType(elementType),
+ operand(operand) {}
+
+ /// The hash key for uniquing.
+ using KeyTy = std::tuple<ArrayRef<int64_t>, Type, StringRef>;
+ bool operator==(const KeyTy &key) const {
+ return key == KeyTy(getShape(), elementType, operand);
+ }
+
+ /// Construction.
+ static MMAMatrixStorageType *construct(TypeStorageAllocator &allocator,
+ const KeyTy &key) {
+ ArrayRef<int64_t> shape = allocator.copyInto(std::get<0>(key));
+ StringRef operand = allocator.copyInto(std::get<2>(key));
+
+ return new (allocator.allocate<MMAMatrixStorageType>())
+ MMAMatrixStorageType(shape.size(), shape.data(), std::get<1>(key),
+ operand);
+ }
+
+ ArrayRef<int64_t> getShape() const {
+ return ArrayRef<int64_t>(dimShapes, numDims);
+ }
+
+ StringRef getOperand() const { return operand; }
+
+ /// Reference to the shape of the MMA matrix.
+ const int64_t *dimShapes;
+
+ /// Number of dimensions in the MMA matrix.
+ unsigned numDims;
+
+ /// Element type of elements held in the MMA matrix.
+ Type elementType;
+
+ /// MMA operand that this MMAMatrix holds. The general form of operation this
+ /// type supports is given by the equation D = (alpha*(A*B)) + (beta*C). This
+ /// field specifies which operand in the given equation is held by this type.
+ /// The valid values are "AOp", "BOp", "COp" and "DOp".
+ StringRef operand;
+};
+
+/// MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply
+/// accumulate operations. MMAMatrices are taken as direct operands by these
+/// operations and are also produced as results. These matrices are meant to
+/// reside in the registers. A limited number of pointwise operations can be
+/// performed on these matrices, i.e., operations which operate uniformly on
+/// all the elements in the matrix and do not change the order of matrix
+/// elements. The above conditions exist because the layout of matrix elements
+/// inside the matrix is opaque, i.e., the elements may be present in the
+/// matrix in any order. The general usage of this type is as follows:
+///
+/// %0 = gpu.subgroup_mma_load_matrix %arg0[%c0, %c0] {leadDimension = 16 :
+/// index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+///
+/// The MMAMatrixType describes the shape of the matrix being loaded and the
+/// operand being loaded too. The operand needs to be specified to aid the
+/// lowering of this type to dialects such as NVVM, where each workitem may
+/// hold a different number of elements depending on the elementType of the
+/// matrix. For example, each workitem holds 4 vector<2xf16>s for an f16
+/// MMAMatrix and 8 f32s for an f32 MMAMatrix. Some other instances of usage
+/// are:
+///
+/// %3 = gpu.subgroup_mma_compute %0, %1, %2 : !gpu.mma_matrix<16x16xf16,
+/// "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf32,
+/// "COp"> -> !gpu.mma_matrix<16x16xf32, "DOp">
+///
+///
+/// gpu.subgroup_mma_store_matrix %3, %arg22[%c0, %c0] {leadDimension = 16
+/// : index}: !gpu.mma_matrix<16x16xf32, "DOp">, memref<16x16xf32>
+// TODO: consider moving this to ODS.
+class MMAMatrixType
+ : public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType> {
+public:
+ using Base::Base;
+
+  /// Get MMAMatrixType and verify construction invariants.
+ static MMAMatrixType get(ArrayRef<int64_t> shape, Type elementType,
+ StringRef operand);
+
+ /// Get MMAMatrixType at a particular location and verify construction
+  /// invariants.
+ static MMAMatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<int64_t> shape, Type elementType,
+ StringRef operand);
+
+  /// Check if a type is a valid MMAMatrixType elementType.
+ static bool isValidElementType(Type elementType);
+
+ /// Verify that shape and elementType are actually allowed for the
+ /// MMAMatrixType.
+ static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<int64_t> shape, Type elementType,
+ StringRef operand);
+
+ /// Get number of dims.
+ unsigned getNumDims() const;
+
+ /// Get shape of the matrix.
+ ArrayRef<int64_t> getShape() const;
+
+ /// Get elementType of a single element.
+ Type getElementType() const;
+
+ /// The general form of operation this type supports is given by the equation
+ /// D = (alpha*(A*B)) + (beta*C). This function returns which operand in the
+  /// given equation is held by this type. The returned string can be one of
+  /// "AOp", "BOp", "COp" and "DOp".
+ StringRef getOperand() const;
+};
+
// Adds a `gpu.async.token` to the front of the argument list.
void addAsyncDependency(Operation *op, Value token);
diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index 41206af46ae2d..5bd6956e75c0c 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -912,4 +912,122 @@ def GPU_MemcpyOp : GPU_Op<"memcpy", [GPU_AsyncOpInterface]> {
let verifier = [{ return ::verify(*this); }];
}
+def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
+ [MemoryEffects<[MemRead]>]>{
+
+ let summary = "GPU warp synchronous matrix load";
+
+ let description = [{
+ The `gpu.subgroup_mma_load_matrix` operation loads a matrix collectively
+ using all the threads in a subgroup.
+
+    This operation takes a memref as its argument: the source matrix from which
+    data is to be loaded. The op returns a `!gpu.mma_matrix`. The source memref
+    can be in the global or shared memory space. The starting load address is
+    determined using the provided indices. The matrix being loaded is specified
+    in the result type. This is necessary because a different LLVM intrinsic
+    exists for loading each operand; this is probably because each operand needs
+    to be laid out in a specific/different way in the registers for the
+    operation. The `leadDimension` attribute specifies the leading dimension of
+    the source matrix.
+
+ This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and
+ `gpu.subgroup_mma_compute`.
+
+ Example:
+
+ ```mlir
+    %0 = gpu.subgroup_mma_load_matrix %src[%i, %j] {leadDimension = 32 : index}
+      : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
+ ```
+ }];
+
+ let arguments = (ins Arg<MemRefRankOf<[F16, F32], [2]>, "", [MemRead]>:$srcMemref,
+ Variadic<Index>:$indices,
+ IndexAttr:$leadDimension);
+
+ let results = (outs GPU_MMAMatrix:$res);
+
+ let assemblyFormat = [{
+ $srcMemref`[`$indices`]` attr-dict `:` type($srcMemref) `->` type($res)
+ }];
+
+ let verifier = [{ return ::verify(*this); }];
+}
+
+def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
+ [MemoryEffects<[MemWrite]>]>{
+
+ let summary = "GPU warp synchronous matrix store";
+
+ let description = [{
+ The `gpu.subgroup_mma_store_matrix` operation stores a matrix collectively
+ using all the threads in a subgroup.
+
+ This operation takes a `!gpu.mma_matrix` and a memref as arguments.
+ `!gpu.mma_matrix` is the source which contains the data to be stored.
+    The destination can be in the global or shared memory space. The starting
+    store address is determined using the provided indices. The `leadDimension`
+ attribute specifies the leading dimension of the destination matrix.
+
+ This op is meant to be used along with `gpu.subgroup_mma_load_matrix` and
+ `gpu.subgroup_mma_compute`.
+
+ Example:
+
+ ```mlir
+    gpu.subgroup_mma_store_matrix %D, %sg[%i, %j] {leadDimension = 32 : index}
+      : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 3>
+ ```
+ }];
+
+ let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$src,
+ Arg<MemRefRankOf<[F16, F32], [2]>, "",[MemWrite]>:$dstMemref,
+ Variadic<Index>:$indices,
+ IndexAttr:$leadDimension);
+
+ let assemblyFormat = [{
+ $src`,` $dstMemref`[`$indices`]` attr-dict `:` type($src)`,` type($dstMemref)
+ }];
+
+ let verifier = [{ return ::verify(*this); }];
+}
+
+def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute", []>{
+
+ let summary = "GPU warp synchronous matrix multiply accumulate";
+
+ let description = [{
+    The `gpu.subgroup_mma_compute` operation performs a matrix-multiply accumulate (mma)
+ operation using all the threads in a subgroup.
+
+    This operation takes three `!gpu.mma_matrix`s as arguments: the `A`, `B` and
+    `C` operands of the mma operation. The operation performed is represented
+ as `D = A * B + C`. The op returns a `!gpu.mma_matrix` which contains the result of
+ the operation held by the current thread.
+
+ This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and
+ `gpu.subgroup_mma_load_matrix`.
+
+ Example:
+
+ ```mlir
+    %D = gpu.subgroup_mma_compute %A, %B, %C :
+ !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">,
+ !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
+ ```
+ }];
+
+ let arguments = (ins Arg<MMAMatrixOf<[F16]>>:$opA,
+ Arg<MMAMatrixOf<[F16]>>:$opB,
+ Arg<MMAMatrixOf<[F16, F32]>>:$opC);
+
+ let results = (outs GPU_MMAMatrix:$res);
+
+ let assemblyFormat = [{
+ $opA`,` $opB`,` $opC attr-dict `:` type($opA)`,` type($opB)`,` type($opC) `->` type($res)
+ }];
+
+ let verifier = [{ return ::verify(*this); }];
+}
+
#endif // GPU_OPS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 203a0b2031c9f..a3f5a84dad59f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -151,4 +151,254 @@ def NVVM_MmaOp :
let verifier = [{ return ::verify(*this); }];
}
+// Base class for all the variants of WMMA loadOps that may be defined.
+class NVVM_WMMALoadOp<string mnemonic> : NVVM_Op<mnemonic>,
+ Results<(outs LLVM_AnyStruct:$res)>,
+ Arguments<(ins Variadic<LLVM_Type>:$args)> {
+
+ let summary = "Warp synchronous matrix load";
+
+  string baseDescription = [{
+    The `nvvm.wmma.m*n*k*.load.[a, b, c]` operation loads a matrix collectively
+    using all the threads in a warp.
+
+    The operation takes two arguments: the address from which the matrix
+    elements are to be loaded, and a stride. The stride argument represents the
+    leading dimension of the source matrix. The address and the stride are
+    required to be the same across all threads in the warp. Each thread in a
+    warp holds a certain number of elements. The op returns an LLVM struct which
+    holds the elements of the matrix held by this thread.
+
+    This op is meant to be used along with `nvvm.wmma.m*n*k*.store` and
+    `nvvm.wmma.m*n*k*.mma`.
+  }];
+
+ let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
+}
+
+def NVVM_WMMALoadAM16N16K16Op :
+ NVVM_WMMALoadOp<"wmma.m16n16k16.load.a.f16.row.stride">{
+
+ string llvmBuilder = [{
+ $res = createNvvmIntrinsicCall(
+ builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride, $args);
+ }];
+
+ string opDescription = [{
+ Example:
+
+ ```mlir
+    %2 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %0, %1 : (!llvm.ptr<i32, 3>, i32) ->
+      !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
+      vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ ```
+ }];
+
+ let description = !strconcat(baseDescription, opDescription);
+
+ let verifier = [{ return ::verify(*this); }];
+}
+
+def NVVM_WMMALoadBM16N16K16Op :
+ NVVM_WMMALoadOp<"wmma.m16n16k16.load.b.f16.row.stride">{
+
+ string llvmBuilder = [{
+ $res = createNvvmIntrinsicCall(
+ builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride, $args);
+ }];
+
+ string opDescription = [{
+ Example:
+
+ ```mlir
+    %2 = nvvm.wmma.m16n16k16.load.b.f16.row.stride %0, %1 : (!llvm.ptr<i32, 3>, i32) ->
+      !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
+      vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ ```
+ }];
+
+ let description = !strconcat(baseDescription, opDescription);
+
+ let verifier = [{ return ::verify(*this); }];
+}
+
+def NVVM_WMMALoadCF16M16N16K16Op :
+ NVVM_WMMALoadOp<"wmma.m16n16k16.load.c.f16.row.stride">{
+ string llvmBuilder = [{
+ $res = createNvvmIntrinsicCall(
+ builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride, $args);
+ }];
+
+ string opDescription = [{
+ Example:
+
+ ```mlir
+    %2 = nvvm.wmma.m16n16k16.load.c.f16.row.stride %0, %1 : (!llvm.ptr<i32, 3>, i32) ->
+      !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ ```
+ }];
+
+ let description = !strconcat(baseDescription, opDescription);
+
+ let verifier = [{ return ::verify(*this); }];
+}
+
+def NVVM_WMMALoadCF32M16N16K16Op :
+ NVVM_WMMALoadOp<"wmma.m16n16k16.load.c.f32.row.stride">{
+ string llvmBuilder = [{
+ $res = createNvvmIntrinsicCall(
+ builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride, $args);
+ }];
+
+ string opDescription = [{
+ Example:
+
+ ```mlir
+    %2 = nvvm.wmma.m16n16k16.load.c.f32.row.stride %0, %1 : (!llvm.ptr<i32, 3>, i32) ->
+      !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+ ```
+ }];
+
+ let description = !strconcat(baseDescription, opDescription);
+
+ let verifier = [{ return ::verify(*this); }];
+}
+
+// Base class for all the variants of WMMA storeOps that may be defined.
+class NVVM_WMMAStoreOp<string mnemonic> : NVVM_Op<mnemonic>,
+ Arguments<(ins Variadic<LLVM_Type>:$args)>{
+ let summary = "Warp synchronous matrix store";
+
+ string baseDescription = [{
+ The `nvvm.wmma.m*n*k*.store` operation stores a matrix collectively using
+ all the threads in a warp.
+
+ The operation takes as arguments the address to where the matrix elements are
+ to be stored, a stride and the elements to store, held by the current thread.
+ The stride argument represents the leading dimension of the destination matrix.
+ The address and the stride are required to be the same across all threads in the
+ warp.
+
+ This op is meant to be used along with `nvvm.wmma.m16n16k16.load` and
+ `nvvm.wmma.m16n16k16.mma`.
+ }];
+
+ let assemblyFormat = "$args attr-dict `:` type($args)";
+}
+
+def NVVM_WMMAStoreF16M16N16K16Op : NVVM_WMMAStoreOp<"wmma.m16n16k16.store.d.f16.row.stride"> {
+ string llvmBuilder = [{
+ createNvvmIntrinsicCall(
+ builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride, $args);
+ }];
+
+ string opDescription = [{
+ Example:
+
+ ```mlir
+    nvvm.wmma.m16n16k16.store.d.f16.row.stride %0, %1, %2, %3, %4, %5 : !llvm.ptr<i32, 3>,
+      vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32
+ ```
+ }];
+
+ let description = !strconcat(baseDescription, opDescription);
+
+ let verifier = [{ return ::verify(*this); }];
+}
+
+def NVVM_WMMAStoreF32M16N16K16Op : NVVM_WMMAStoreOp<"wmma.m16n16k16.store.d.f32.row.stride"> {
+ string llvmBuilder = [{
+ createNvvmIntrinsicCall(
+ builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride, $args);
+ }];
+
+ string opDescription = [{
+ Example:
+
+ ```mlir
+    nvvm.wmma.m16n16k16.store.d.f32.row.stride %0, %1, %2, %3, %4, %5, %6, %7, %8, %9 :
+      !llvm.ptr<i32, 3>, f32, f32, f32, f32, f32, f32, f32, f32, i32
+ ```
+ }];
+
+ let description = !strconcat(baseDescription, opDescription);
+
+ let verifier = [{ return ::verify(*this); }];
+}
+
+// Base class for all the variants of WMMA mmaOps that may be defined.
+class NVVM_WMMAMmaOp<string mnemonic> : NVVM_Op<mnemonic>,
+ Results<(outs LLVM_AnyStruct:$res)>,
+ Arguments<(ins Variadic<LLVM_Type>:$args)>{
+ let summary = "Warp synchronous matrix-multiply accumulate using tensor cores.";
+
+ string baseDescription = [{
+ The `nvvm.wmma.m*n*k*.mma` operation performs a matrix-multiply accumulate
+ (mma) operation using all the threads in a warp.
+
+ The operation performed is represented as `D = A * B + C`. The operation takes
+ as arguments the elements of the matrices `A`, `B`, `C` and `D`, held by the
+    current thread. The op returns an LLVM struct which holds a part of the result
+ held by the current thread.
+
+ This op is meant to be used along with `nvvm.wmma.m16n16k16.load` and `nvvm.wmma.
+ m16n16k16.store`.
+ }];
+}
+
+def NVVM_WMMAMmaF16F16M16N16K16Op : NVVM_WMMAMmaOp<"wmma.m16n16k16.mma.row.row.f16.f16">{
+ string llvmBuilder = [{
+ $res = createNvvmIntrinsicCall(
+ builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f16_f16, $args);
+ }];
+
+ string opDescription = [{
+ Example:
+
+ ```mlir
+ %20 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %0, %1, %2, %3, %4, %5, %6, %7, %8,
+ %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19 : vector<2xf16> -> !llvm.struct
+ <(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ ```
+ }];
+
+ let parser = [{
+ return parseWMMAMmaF16F16M16N16K16Op(parser, result);
+ }];
+
+ let printer = [{
+ printWMMAMmaF16F16M16N16K16Op(p, *this);
+ }];
+
+ let description = !strconcat(baseDescription, opDescription);
+
+ let verifier = [{ return ::verify(*this); }];
+}
+
+def NVVM_WMMAMmaF32F32M16N16K16Op : NVVM_WMMAMmaOp<"wmma.m16n16k16.mma.row.row.f32.f32">{
+ string llvmBuilder = [{
+ $res = createNvvmIntrinsicCall(
+ builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f32_f32, $args);
+ }];
+
+ string opDescription = [{
+ Example:
+
+ ```mlir
+    %24 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %0, %1, %2, %3, %4, %5, %6, %7, %8,
+ %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23 :
+ (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
+ vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
+ vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
+ vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32,
+ f32, f32, f32, f32, f32, f32, f32)>
+ ```
+ }];
+
+ let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
+
+ let description = !strconcat(baseDescription, opDescription);
+
+ let verifier = [{ return ::verify(*this); }];
+}
+
#endif // NVVMIR_OPS
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index c6ceecd2b3ddd..137961443af33 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -257,6 +257,14 @@ llvm::Value *createIntrinsicCall(llvm::IRBuilderBase &builder,
llvm::Intrinsic::ID intrinsic,
ArrayRef<llvm::Value *> args = {},
ArrayRef<llvm::Type *> tys = {});
+
+/// Creates a call to an LLVM IR intrinsic function with the given arguments
+/// for NVVM WMMA ops. When the intrinsic name is overloaded, the correct
+/// variant is selected by inspecting the types of the supplied arguments.
+llvm::Value *createNvvmIntrinsicCall(llvm::IRBuilderBase &builder,
+ llvm::Intrinsic::ID intrinsic,
+ ArrayRef<llvm::Value *> args = {});
} // namespace detail
} // namespace LLVM
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 27fde9b87405f..1fa687f83f0d9 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -28,10 +28,70 @@
using namespace mlir;
using namespace mlir::gpu;
+//===----------------------------------------------------------------------===//
+// MMAMatrixType
+//===----------------------------------------------------------------------===//
+
+MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
+ StringRef operand) {
+ return Base::get(elementType.getContext(), shape, elementType, operand);
+}
+
+MMAMatrixType
+MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<int64_t> shape, Type elementType,
+ StringRef operand) {
+ return Base::getChecked(emitError, elementType.getContext(), shape,
+ elementType, operand);
+}
+
+unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }
+
+ArrayRef<int64_t> MMAMatrixType::getShape() const {
+ return getImpl()->getShape();
+}
+
+Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
+
+StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
+
+bool MMAMatrixType::isValidElementType(Type elementType) {
+ return elementType.isF16() || elementType.isF32();
+}
+
+LogicalResult
+MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<int64_t> shape, Type elementType,
+ StringRef operand) {
+ if (!operand.equals("AOp") && !operand.equals("BOp") &&
+ !operand.equals("COp") && !operand.equals("DOp"))
+ return emitError() << "operand expected to be one of AOp, BOp, COp or DOp";
+
+ if (shape.size() != 2)
+ return emitError() << "MMAMatrixType must have exactly two dimensions";
+
+ if (!MMAMatrixType::isValidElementType(elementType))
+ return emitError() << "MMAMatrixType elements must be F16 or F32";
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//
+/// GPU memory space identifiers.
+enum GPUMemorySpace {
+ /// Generic memory space identifier.
+ kGenericMemorySpace = 0,
+
+ /// Global memory space identifier.
+ kGlobalMemorySpace = 1,
+
+ /// Shared memory space identifier.
+ kSharedMemorySpace = 3
+};
+
bool GPUDialect::isKernel(Operation *op) {
UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
return static_cast<bool>(isKernelAttr);
@@ -39,6 +99,7 @@ bool GPUDialect::isKernel(Operation *op) {
void GPUDialect::initialize() {
addTypes<AsyncTokenType>();
+ addTypes<MMAMatrixType>();
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
@@ -56,6 +117,38 @@ Type GPUDialect::parseType(DialectAsmParser &parser) const {
if (keyword == "async.token")
return AsyncTokenType::get(context);
+ if (keyword == "mma_matrix") {
+ llvm::SMLoc beginLoc = parser.getNameLoc();
+
+ // Parse '<'.
+ if (parser.parseLess())
+ return nullptr;
+
+ // Parse the size and elementType.
+ SmallVector<int64_t> shape;
+ Type elementType;
+ if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
+ parser.parseType(elementType))
+ return nullptr;
+
+ // Parse ','
+ if (parser.parseComma())
+ return nullptr;
+
+ // Parse operand.
+ StringRef operand;
+ if (failed(parser.parseOptionalString(&operand)))
+ return nullptr;
+
+ // Parse '>'.
+ if (parser.parseGreater())
+ return nullptr;
+
+ return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
+ parser.getEncodedSourceLoc(beginLoc)),
+ shape, elementType, operand);
+ }
+
parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
return Type();
}
@@ -63,6 +156,14 @@ Type GPUDialect::parseType(DialectAsmParser &parser) const {
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
TypeSwitch<Type>(type)
.Case<AsyncTokenType>([&](Type) { os << "async.token"; })
+ .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
+ os << "mma_matrix<";
+ auto shape = fragTy.getShape();
+ for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
+ os << *dim << 'x';
+ os << shape.back() << 'x' << fragTy.getElementType();
+ os << ", \"" << fragTy.getOperand() << "\"" << '>';
+ })
.Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
}
@@ -138,7 +239,8 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
return walkResult.wasInterrupted() ? failure() : success();
}
-template <typename T> static LogicalResult verifyIndexOp(T op) {
+template <typename T>
+static LogicalResult verifyIndexOp(T op) {
auto dimension = op.dimension();
if (dimension != "x" && dimension != "y" && dimension != "z")
return op.emitError("dimension \"") << dimension << "\" is invalid";
@@ -885,6 +987,95 @@ static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
printer << "]";
}
+//===----------------------------------------------------------------------===//
+// GPU_SubgroupMmaLoadMatrixOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(SubgroupMmaLoadMatrixOp op) {
+ auto srcType = op.srcMemref().getType();
+ auto resType = op.res().getType();
+ auto resMatrixType = resType.cast<gpu::MMAMatrixType>();
+ auto operand = resMatrixType.getOperand();
+ auto srcMemrefType = srcType.cast<MemRefType>();
+ auto srcMemSpace = srcMemrefType.getMemorySpaceAsInt();
+
+ if (!srcMemrefType.getAffineMaps().empty() &&
+ !srcMemrefType.getAffineMaps().front().isIdentity())
+ return op.emitError("expected identity layout map for source memref");
+
+ if (srcMemSpace != kGenericMemorySpace && srcMemSpace != kSharedMemorySpace &&
+ srcMemSpace != kGlobalMemorySpace)
+ return op.emitError(
+ "source memorySpace kGenericMemorySpace, kSharedMemorySpace or "
+ "kGlobalMemorySpace only allowed");
+
+ if (!operand.equals("AOp") && !operand.equals("BOp") &&
+ !operand.equals("COp"))
+ return op.emitError("only AOp, BOp and COp can be loaded");
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// GPU_SubgroupMmaStoreMatrixOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(SubgroupMmaStoreMatrixOp op) {
+ auto srcType = op.src().getType();
+ auto dstType = op.dstMemref().getType();
+ auto srcMatrixType = srcType.cast<gpu::MMAMatrixType>();
+ auto dstMemrefType = dstType.cast<MemRefType>();
+ auto dstMemSpace = dstMemrefType.getMemorySpaceAsInt();
+
+ if (!dstMemrefType.getAffineMaps().empty() &&
+ !dstMemrefType.getAffineMaps().front().isIdentity())
+ return op.emitError("expected identity layout map for destination memref");
+
+ if (dstMemSpace != kGenericMemorySpace && dstMemSpace != kSharedMemorySpace &&
+ dstMemSpace != kGlobalMemorySpace)
+ return op.emitError(
+ "destination memorySpace of kGenericMemorySpace, "
+ "kGlobalMemorySpace or kSharedMemorySpace only allowed");
+
+ if (!srcMatrixType.getOperand().equals("DOp"))
+ return op.emitError(
+ "expected the operand matrix being stored to have 'DOp' operand type");
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// GPU_SubgroupMmaComputeOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(SubgroupMmaComputeOp op) {
+ enum OperandMap { A, B, C };
+ SmallVector<MMAMatrixType, 3> opTypes;
+
+ auto populateOpInfo = [&opTypes, &op]() {
+ opTypes.push_back(op.opA().getType().cast<MMAMatrixType>());
+ opTypes.push_back(op.opB().getType().cast<MMAMatrixType>());
+ opTypes.push_back(op.opC().getType().cast<MMAMatrixType>());
+ };
+ populateOpInfo();
+
+ if (!opTypes[A].getOperand().equals("AOp") ||
+ !opTypes[B].getOperand().equals("BOp") ||
+ !opTypes[C].getOperand().equals("COp"))
+ return op.emitError("operands must be in the order AOp, BOp, COp");
+
+ ArrayRef<int64_t> aShape, bShape, cShape;
+ aShape = opTypes[A].getShape();
+ bShape = opTypes[B].getShape();
+ cShape = opTypes[C].getShape();
+
+ if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
+ bShape[1] != cShape[1])
+ return op.emitError("operand shapes do not satisfy matmul constraints");
+
+ return success();
+}
+
#include "mlir/Dialect/GPU/GPUOpInterfaces.cpp.inc"
#define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 3b6d2395d0ca5..c16e1c2f7af4f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -94,12 +94,12 @@ static LogicalResult verify(MmaOp op) {
auto f32x8StructTy = LLVM::LLVMStructType::getLiteral(
context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
- SmallVector<Type, 12> operand_types(op.getOperandTypes().begin(),
- op.getOperandTypes().end());
- if (operand_types != SmallVector<Type, 8>(8, f16x2Ty) &&
- operand_types != SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
- f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
- f32Ty, f32Ty, f32Ty}) {
+ SmallVector<Type, 12> operandTypes(op.getOperandTypes().begin(),
+ op.getOperandTypes().end());
+ if (operandTypes != SmallVector<Type, 8>(8, f16x2Ty) &&
+ operandTypes != SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
+ f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
+ f32Ty, f32Ty, f32Ty}) {
return op.emitOpError(
"expected operands to be 4 <halfx2>s followed by either "
"4 <halfx2>s or 8 floats");
@@ -120,9 +120,9 @@ static LogicalResult verify(MmaOp op) {
"\"row\" or \"col\"");
}
- if (operand_types == SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
- f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
- f32Ty, f32Ty, f32Ty} &&
+ if (operandTypes == SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
+ f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
+ f32Ty, f32Ty, f32Ty} &&
op.getType() == f32x8StructTy && alayout.getValue() == "row" &&
blayout.getValue() == "col") {
return success();
@@ -130,6 +130,205 @@ static LogicalResult verify(MmaOp op) {
return op.emitOpError("unimplemented mma.sync variant");
}
+template <typename T>
+static LogicalResult verifyWMMALoadOp(T op, StringRef operand) {
+ MLIRContext *context = op.getContext();
+ auto i32Ty = IntegerType::get(context, 32);
+ auto i32Ptr1Ty = LLVM::LLVMPointerType::get(i32Ty, 1);
+ auto i32Ptr3Ty = LLVM::LLVMPointerType::get(i32Ty, 3);
+ auto i32Ptr0Ty = LLVM::LLVMPointerType::get(i32Ty, 0);
+ auto f16Ty = FloatType::getF16(context);
+ auto f32Ty = FloatType::getF32(context);
+ auto f16x2Ty = VectorType::get(2, f16Ty);
+ auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
+ context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
+ auto f16x2x8StructTy = LLVM::LLVMStructType::getLiteral(
+ context,
+ {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
+ auto f32x8StructTy = LLVM::LLVMStructType::getLiteral(
+ context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
+
+ SmallVector<Type, 2> operandTypes(op.getOperandTypes().begin(),
+ op.getOperandTypes().end());
+ if (operandTypes != SmallVector<Type, 2>{i32Ptr1Ty, i32Ty} &&
+ operandTypes != SmallVector<Type, 2>{i32Ptr3Ty, i32Ty} &&
+ operandTypes != SmallVector<Type, 2>{i32Ptr0Ty, i32Ty}) {
+ return op.emitOpError("expected operands to be a source pointer in memory "
+ "space 0, 1, 3 followed by ldm of the source");
+ }
+
+ if (operand.equals("AOp") || operand.equals("BOp")) {
+ if (op.getType() != f16x2x8StructTy) {
+ return op.emitOpError("expected result type of loadAOp and loadBOp to be "
+ "a struct of 8 <halfx2>s");
+ }
+ } else if (operand.equals("COp")) {
+ if (op.getType() != f16x2x4StructTy && op.getType() != f32x8StructTy) {
+ return op.emitOpError("expected result type of loadCOp to be a struct of "
+ "4 <halfx2>s or 8 f32s");
+ }
+ }
+
+ return success();
+}
+
+static LogicalResult verify(WMMALoadAM16N16K16Op op) {
+ return verifyWMMALoadOp(op, "AOp");
+}
+
+static LogicalResult verify(WMMALoadBM16N16K16Op op) {
+ return verifyWMMALoadOp(op, "BOp");
+}
+
+static LogicalResult verify(WMMALoadCF16M16N16K16Op op) {
+ return verifyWMMALoadOp(op, "COp");
+}
+
+static LogicalResult verify(WMMALoadCF32M16N16K16Op op) {
+ return verifyWMMALoadOp(op, "COp");
+}
+
+template <typename T>
+static bool verifyWMMAStoreOp(T op, SmallVector<Type> &containedElems) {
+ SmallVector<Type> operandTypes(op.getOperandTypes().begin(),
+ op.getOperandTypes().end());
+ if (operandTypes == containedElems)
+ return true;
+
+ return false;
+}
+
+static LogicalResult verify(WMMAStoreF16M16N16K16Op op) {
+ MLIRContext *context = op.getContext();
+ auto i32Ty = IntegerType::get(context, 32);
+ auto i32Ptr1Ty = LLVM::LLVMPointerType::get(i32Ty, 1);
+ auto i32Ptr3Ty = LLVM::LLVMPointerType::get(i32Ty, 3);
+ auto i32Ptr0Ty = LLVM::LLVMPointerType::get(i32Ty, 0);
+ auto f16Ty = FloatType::getF16(context);
+ auto f16x2Ty = VectorType::get(2, f16Ty);
+ SmallVector<Type> type1{i32Ptr1Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, i32Ty};
+ SmallVector<Type> type0{i32Ptr0Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, i32Ty};
+ SmallVector<Type> type3{i32Ptr3Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, i32Ty};
+ if (verifyWMMAStoreOp(op, type1) || verifyWMMAStoreOp(op, type0) ||
+ verifyWMMAStoreOp(op, type3))
+ return success();
+
+ return op.emitOpError("expected operands to be a source pointer in memory"
+ "space 0, 1, 3 followed by ldm of the source");
+}
+
+static LogicalResult verify(WMMAStoreF32M16N16K16Op op) {
+ MLIRContext *context = op.getContext();
+ auto i32Ty = IntegerType::get(context, 32);
+ auto i32Ptr1Ty = LLVM::LLVMPointerType::get(i32Ty, 1);
+ auto i32Ptr3Ty = LLVM::LLVMPointerType::get(i32Ty, 3);
+ auto i32Ptr0Ty = LLVM::LLVMPointerType::get(i32Ty, 0);
+ auto f32Ty = FloatType::getF32(context);
+
+ SmallVector<Type> type1{i32Ptr1Ty, f32Ty, f32Ty, f32Ty, f32Ty,
+ f32Ty, f32Ty, f32Ty, f32Ty, i32Ty};
+ SmallVector<Type> type0{i32Ptr0Ty, f32Ty, f32Ty, f32Ty, f32Ty,
+ f32Ty, f32Ty, f32Ty, f32Ty, i32Ty};
+ SmallVector<Type> type3{i32Ptr3Ty, f32Ty, f32Ty, f32Ty, f32Ty,
+ f32Ty, f32Ty, f32Ty, f32Ty, i32Ty};
+ if (verifyWMMAStoreOp(op, type0) || verifyWMMAStoreOp(op, type1) ||
+ verifyWMMAStoreOp(op, type3))
+ return success();
+
+ return op.emitOpError("expected operands to be a source pointer in memory"
+ "space 0, 1, 3 followed by ldm of the source");
+}
+
+static LogicalResult verify(WMMAMmaF16F16M16N16K16Op op) {
+ MLIRContext *context = op.getContext();
+ auto f16Ty = FloatType::getF16(context);
+ auto f16x2Ty = VectorType::get(2, f16Ty);
+ auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
+ context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
+
+ SmallVector<Type, 2> operandTypes(op.getOperandTypes().begin(),
+ op.getOperandTypes().end());
+ if (operandTypes != SmallVector<Type, 20>(20, f16x2Ty))
+ return op.emitOpError("expected 20 <halfx2>s as operands");
+
+ if (op.getResult().getType() != f16x2x4StructTy)
+ return op.emitOpError("expected result type to be a struct of 4 <halfx2>s");
+
+ return success();
+}
+
+static LogicalResult parseWMMAMmaF16F16M16N16K16Op(OpAsmParser &parser,
+ OperationState &result) {
+ SmallVector<OpAsmParser::OperandType, 4> operands;
+ ::llvm::SMLoc operandsLoc;
+ Type operandType;
+ Type resType;
+
+ operandsLoc = parser.getCurrentLocation();
+ if (parser.parseOperandList(operands) ||
+ parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
+ parser.parseType(operandType) || parser.parseArrow())
+ return failure();
+
+ unsigned numOperands = operands.size();
+ SmallVector<Type> operandTypes(numOperands, operandType);
+ if (parser.parseType(resType))
+ return failure();
+ result.addTypes(resType);
+ if (parser.resolveOperands(operands, operandTypes, operandsLoc,
+ result.operands))
+ return failure();
+ return success();
+}
+
+static void printWMMAMmaF16F16M16N16K16Op(OpAsmPrinter &p,
+ WMMAMmaF16F16M16N16K16Op &op) {
+ p << op.getOperationName();
+ p << ' ';
+ p << op.args();
+ p.printOptionalAttrDict(op->getAttrs(), {});
+ p << " : ";
+ p << op->getOperand(0).getType();
+ p << ' ' << "->";
+ p << ' ';
+ p << ::llvm::ArrayRef<::mlir::Type>(op.res().getType());
+}
+
+static LogicalResult verify(WMMAMmaF32F32M16N16K16Op op) {
+ unsigned numABOperands = 16;
+ unsigned numCOperands = 8;
+ MLIRContext *context = op.getContext();
+ auto f16Ty = FloatType::getF16(context);
+ auto f32Ty = FloatType::getF32(context);
+ auto f16x2Ty = VectorType::get(2, f16Ty);
+ auto f32x8StructTy = LLVM::LLVMStructType::getLiteral(
+ context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
+
+ SmallVector<Type> abOpTypes;
+ SmallVector<Type> bOpTypes;
+ SmallVector<Type> cOpTypes;
+
+ for (auto operand : op->getOperands().take_front(numABOperands)) {
+ abOpTypes.push_back(operand.getType());
+ }
+
+ for (auto operand :
+ op->getOperands().drop_front(numABOperands).take_front(numCOperands)) {
+ cOpTypes.push_back(operand.getType());
+ }
+
+ if (abOpTypes != SmallVector<Type>(16, f16x2Ty))
+ return op.emitOpError("expected 16 <halfx2>s for `a` and `b` operand");
+
+ if (cOpTypes != SmallVector<Type>(8, f32Ty))
+ return op.emitOpError("expected 8 f32s for `c` operand");
+
+ if (op.getResult().getType() != f32x8StructTy)
+ return op.emitOpError("expected result type to be a struct of 8 f32s");
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
@@ -141,7 +340,8 @@ void NVVMDialect::initialize() {
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
>();
- // Support unknown operations because not all NVVM operations are registered.
+ // Support unknown operations because not all NVVM operations are
+ // registered.
allowUnknownOperations();
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index bdcb451323add..7675387690a44 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -22,6 +22,7 @@
using namespace mlir;
using namespace mlir::LLVM;
using mlir::LLVM::detail::createIntrinsicCall;
+using mlir::LLVM::detail::createNvvmIntrinsicCall;
static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType,
bool withPredicate) {
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ea4c32cf6a2ed..fb5319546adc3 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -35,6 +35,7 @@
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InlineAsm.h"
+#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
@@ -300,6 +301,29 @@ llvm::Value *mlir::LLVM::detail::createIntrinsicCall(
return builder.CreateCall(fn, args);
}
+llvm::Value *
+mlir::LLVM::detail::createNvvmIntrinsicCall(llvm::IRBuilderBase &builder,
+ llvm::Intrinsic::ID intrinsic,
+ ArrayRef<llvm::Value *> args) {
+ llvm::Module *module = builder.GetInsertBlock()->getModule();
+ llvm::Function *fn;
+ if (llvm::Intrinsic::isOverloaded(intrinsic)) {
+ if (intrinsic != llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f16_f16 &&
+ intrinsic != llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f32_f32) {
+      // NVVM load and store intrinsic names are overloaded on the
+      // source/destination pointer type. The pointer is the first argument in
+      // the corresponding NVVM op.
+ fn = llvm::Intrinsic::getDeclaration(module, intrinsic,
+ {args[0]->getType()});
+ } else {
+ fn = llvm::Intrinsic::getDeclaration(module, intrinsic, {});
+ }
+ } else {
+ fn = llvm::Intrinsic::getDeclaration(module, intrinsic);
+ }
+ return builder.CreateCall(fn, args);
+}
+
/// Given a single MLIR operation, create the corresponding LLVM IR operation
/// using the `builder`.
LogicalResult
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 4879c51479e06..58eca3b875855 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -458,3 +458,116 @@ func @memcpy_incompatible_shape(%dst : memref<7xf32>, %src : memref<9xf32>) {
// expected-error @+1 {{'gpu.memcpy' op arguments have incompatible shape}}
gpu.memcpy %dst, %src : memref<7xf32>, memref<9xf32>
}
+
+// -----
+
+func @mmamatrix_invalid_shape(){
+ %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+ %i = constant 16 : index
+ // expected-error @+1 {{MMAMatrixType must have exactly two dimensions}}
+ %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16x16xf16, "AOp">
+ return
+}
+
+// -----
+
+func @mmamatrix_operand_type(){
+ %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+ %i = constant 16 : index
+ // expected-error @+1 {{operand expected to be one of AOp, BOp, COp or DOp}}
+ %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "EOp">
+ return
+}
+
+// -----
+
+func @mmamatrix_invalid_element_type(){
+ %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+ %i = constant 16 : index
+ // expected-error @+1 {{MMAMatrixType elements must be F16 or F32}}
+ %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xi32, "AOp">
+ return
+}
+
+// -----
+
+#layout_map_col_major = affine_map<(i, j) -> (j, i)>
+
+func @mmaLoadOp_identity_layout(){
+ %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, #layout_map_col_major, 3>
+ %i = constant 16 : index
+ // expected-error @+1 {{expected identity layout map for source memref}}
+ %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, #layout_map_col_major, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
+ return
+}
+
+// -----
+
+func @mmaLoadOp_invalid_mem_space(){
+ %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 5>
+ %i = constant 16 : index
+ // expected-error @+1 {{source memorySpace kGenericMemorySpace, kSharedMemorySpace or kGlobalMemorySpace only allowed}}
+ %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 5> -> !gpu.mma_matrix<16x16xf16, "AOp">
+ return
+}
+
+// -----
+
+func @mmaLoadOp_operand_type(){
+ %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+ %i = constant 16 : index
+ // expected-error @+1 {{only AOp, BOp and COp can be loaded}}
+ %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "DOp">
+ return
+}
+
+// -----
+
+#layout_map_col_major = affine_map<(i, j) -> (j, i)>
+
+func @wmmaStoreOp_invalid_map(%arg0 : !gpu.mma_matrix<16x16xf16, "DOp">) -> () {
+ %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, #layout_map_col_major, 3>
+ %i = constant 16 : index
+ %j = constant 16 : index
+ // expected-error @+1 {{expected identity layout map for destination memref}}
+ gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16,#layout_map_col_major, 3>
+ return
+}
+
+// -----
+
+func @wmmaStoreOp_invalid_mem_space(%arg0 : !gpu.mma_matrix<16x16xf16, "DOp">) -> () {
+ %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 5>
+ %i = constant 16 : index
+ %j = constant 16 : index
+ // expected-error @+1 {{destination memorySpace of kGenericMemorySpace, kGlobalMemorySpace or kSharedMemorySpace only allowed}}
+ gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 5>
+ return
+}
+
+// -----
+
+func @wmmaStoreOp_invalid_store_operand(%arg0 : !gpu.mma_matrix<16x16xf16, "AOp">) -> () {
+ %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3>
+ %i = constant 16 : index
+ %j = constant 16 : index
+ // expected-error @+1 {{expected the operand matrix being stored to have 'DOp' operand type}}
+ gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "AOp">, memref<32x32xf16, 3>
+ return
+}
+
+// -----
+
+func @wmmaMmaOp_invalid_operand_order(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
+ // expected-error @+1 {{operands must be in the order AOp, BOp, COp}}
+ %D = gpu.subgroup_mma_compute %B, %A, %C : !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
+ return
+}
+
+// -----
+
+func @wmmaMmaOp_invalid_operand_shapes(%A : !gpu.mma_matrix<16x32xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
+ // expected-error @+1 {{operand shapes do not satisfy matmul constraints}}
+ %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x32xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
+ return
+}
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 7c1420115fc7b..a98fe1c496838 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -194,4 +194,15 @@ module attributes {gpu.container_module} {
%1 = gpu.memcpy async [%0] %dst, %src : memref<3x7xf32>, memref<3x7xf32, 1>
return
}
+
+ func @mmamatrix_valid_element_type(){
+ // CHECK-LABEL: func @mmamatrix_valid_element_type
+ %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+ // CHECK: %[[wg:.*]] = memref.alloca()
+ %i = constant 16 : index
+ // CHECK: %[[i:.*]] = constant 16 : index
+ %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
+ // CHECK: gpu.subgroup_mma_load_matrix %[[wg]][%[[i]], %[[i]]] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
+ return
+ }
}
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 172a79d38c4d5..4c62f38a4af2c 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -843,3 +843,162 @@ module {
llvm.return
}
}
+
+// -----
+
+llvm.func @wmmaLoadOp_invalid_mem_space(%arg0: !llvm.ptr<i32, 5>, %arg1: i32) {
+  // expected-error@+1 {{'nvvm.wmma.m16n16k16.load.a.f16.row.stride' op expected operands to be a source pointer in memory space 0, 1, 3 followed by ldm of the source}}
+ %0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32, 5>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+
+ llvm.return
+}
+
+// -----
+
+llvm.func @wmmaLoadOp_invalid_missing_ldm(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
+  // expected-error@+1 {{'nvvm.wmma.m16n16k16.load.a.f16.row.stride' op expected operands to be a source pointer in memory space 0, 1, 3 followed by ldm of the source}}
+ %0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0: (!llvm.ptr<i32, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+
+ llvm.return
+}
+
+// -----
+
+llvm.func @wmmaLoadOp_invalid_AOp(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
+  // expected-error@+1 {{'nvvm.wmma.m16n16k16.load.a.f16.row.stride' op expected result type of loadAOp and loadBOp to be a struct of 8 <halfx2>s}}
+ %0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+
+ llvm.return
+}
+
+// -----
+
+llvm.func @wmmaLoadOp_invalid_AOp(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
+  // expected-error@+1 {{'nvvm.wmma.m16n16k16.load.a.f16.row.stride' op expected result type of loadAOp and loadBOp to be a struct of 8 <halfx2>s}}
+ %0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+
+ llvm.return
+}
+
+// -----
+
+llvm.func @wmmaLoadOp_invalid_BOp(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
+  // expected-error@+1 {{'nvvm.wmma.m16n16k16.load.b.f16.row.stride' op expected result type of loadAOp and loadBOp to be a struct of 8 <halfx2>s}}
+ %0 = nvvm.wmma.m16n16k16.load.b.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+
+ llvm.return
+}
+
+// -----
+
+llvm.func @wmmaLoadOp_invalid_COp(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
+  // expected-error@+1 {{'nvvm.wmma.m16n16k16.load.c.f16.row.stride' op expected result type of loadCOp to be a struct of 4 <halfx2>s or 8 f32s}}
+ %0 = nvvm.wmma.m16n16k16.load.c.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+
+ llvm.return
+}
+
+// -----
+
+llvm.func @wmmaStoreOp_invalid_mem_space(%arg0: !llvm.ptr<i32, 5>, %arg1: vector<2 x f16>,
+ %arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
+ %arg4: vector<2 xf16>, %arg5: i32) {
+  // expected-error@+1 {{'nvvm.wmma.m16n16k16.store.d.f16.row.stride' op expected operands to be a source pointer in memoryspace 0, 1, 3 followed by ldm of the source}}
+ nvvm.wmma.m16n16k16.store.d.f16.row.stride %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : !llvm.ptr<i32, 5>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, i32
+ llvm.return
+}
+
+// -----
+
+llvm.func @wmmaStoreOp_invalid_missing_ldm(%arg0: !llvm.ptr<i32, 3>, %arg1: vector<2 x f16>,
+ %arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
+ %arg4: vector<2 xf16>, %arg5: i32) {
+  // expected-error@+1 {{'nvvm.wmma.m16n16k16.store.d.f16.row.stride' op expected operands to be a source pointer in memoryspace 0, 1, 3 followed by ldm of the source}}
+ nvvm.wmma.m16n16k16.store.d.f16.row.stride %arg0, %arg1, %arg2, %arg3, %arg4 : !llvm.ptr<i32, 3>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>
+ llvm.return
+}
+
+// -----
+
+llvm.func @gpu_wmma_mma_op_invalid_operands(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>,
+ %arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
+ %arg4: vector<2 x f16>, %arg5: vector<2 x f16>,
+ %arg6: vector<2 x f16>, %arg7: vector<2 x f16>,
+ %arg8: vector<2 x f16>, %arg9: vector<2 x f16>,
+ %arg10: vector<2 x f16>, %arg11: vector<2 x f16>,
+ %arg12: vector<2 x f16>, %arg13: vector<2 x f16>,
+ %arg14: vector<2 x f16>, %arg15: vector<2 x f16>,
+ %arg16: vector<2 x f16>, %arg17: vector<2 x f16>,
+ %arg18: vector<2 x f16>) {
+  // expected-error@+1 {{'nvvm.wmma.m16n16k16.mma.row.row.f16.f16' op expected 20 <halfx2>s as operands}}
+ %0 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18 : vector<2 x f16> -> !llvm.struct<(vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)>
+ llvm.return
+}
+
+// -----
+
+llvm.func @gpu_wmma_mma_op_results(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>,
+ %arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
+ %arg4: vector<2 x f16>, %arg5: vector<2 x f16>,
+ %arg6: vector<2 x f16>, %arg7: vector<2 x f16>,
+ %arg8: vector<2 x f16>, %arg9: vector<2 x f16>,
+ %arg10: vector<2 x f16>, %arg11: vector<2 x f16>,
+ %arg12: vector<2 x f16>, %arg13: vector<2 x f16>,
+ %arg14: vector<2 x f16>, %arg15: vector<2 x f16>,
+ %arg16: vector<2 x f16>, %arg17: vector<2 x f16>,
+ %arg18: vector<2 x f16>, %arg19: vector<2 x f16>) {
+  // expected-error@+1 {{expected result type to be a struct of 4 <halfx2>s}}
+ %0 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19 : vector<2 x f16> -> !llvm.struct<(vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)>
+ llvm.return
+}
+
+// -----
+
+llvm.func @gpu_wmma_mma_op_invalid_ab_operands(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>,
+ %arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
+ %arg4: vector<2 x f16>, %arg5: vector<2 x f16>,
+ %arg6: vector<2 x f16>, %arg7: vector<2 x f16>,
+ %arg8: vector<2 x f16>, %arg9: vector<2 x f16>,
+ %arg10: vector<2 x f16>, %arg11: vector<2 x f16>,
+ %arg12: vector<2 x f16>, %arg13: vector<2 x f16>,
+ %arg14: vector<2 x f16>, %arg15: f32,
+ %arg16: f32, %arg17: f32, %arg18: f32, %arg19: f32,
+ %arg20: f32, %arg21: f32, %arg22: f32, %arg23: f32) {
+  // expected-error@+1 {{'nvvm.wmma.m16n16k16.mma.row.row.f32.f32' op expected 16 <halfx2>s for `a` and `b` operand}}
+ %0 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+ llvm.return
+}
+
+// -----
+
+llvm.func @gpu_wmma_mma_op_invalid_c_operand(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>,
+ %arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
+ %arg4: vector<2 x f16>, %arg5: vector<2 x f16>,
+ %arg6: vector<2 x f16>, %arg7: vector<2 x f16>,
+ %arg8: vector<2 x f16>, %arg9: vector<2 x f16>,
+ %arg10: vector<2 x f16>, %arg11: vector<2 x f16>,
+ %arg12: vector<2 x f16>, %arg13: vector<2 x f16>,
+ %arg14: vector<2 x f16>, %arg15: vector<2xf16>,
+ %arg16: f32, %arg17: f32, %arg18: f32, %arg19: f32,
+ %arg20: f32, %arg21: f32, %arg22: f32, %arg23: vector<2xf16>) {
+  // expected-error@+1 {{'nvvm.wmma.m16n16k16.mma.row.row.f32.f32' op expected 8 f32s for `c` operand}}
+ %0 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+ llvm.return
+}
+
+// -----
+
+llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>,
+ %arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
+ %arg4: vector<2 x f16>, %arg5: vector<2 x f16>,
+ %arg6: vector<2 x f16>, %arg7: vector<2 x f16>,
+ %arg8: vector<2 x f16>, %arg9: vector<2 x f16>,
+ %arg10: vector<2 x f16>, %arg11: vector<2 x f16>,
+ %arg12: vector<2 x f16>, %arg13: vector<2 x f16>,
+ %arg14: vector<2 x f16>, %arg15: vector<2xf16>,
+ %arg16: f32, %arg17: f32, %arg18: f32, %arg19: f32,
+ %arg20: f32, %arg21: f32, %arg22: f32, %arg23: f32) {
+  // expected-error@+1 {{'nvvm.wmma.m16n16k16.mma.row.row.f32.f32' op expected result type to be a struct of 8 f32s}}
+ %0 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, vector<2xf16>)>
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index c1b29fe515ee0..8fefd7866fd5d 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -73,6 +73,43 @@ llvm.func @nvvm_mma(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
}
+// The test below checks that the nvvm.wmma.*.load.* op is mapped to the correct
+// intrinsic in the LLVM NVPTX backend.
+llvm.func @gpu_wmma_load_op(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
+ // CHECK: call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p3i32(i32 addrspace(3)* %{{.*}}, i32 %{{.*}})
+ %0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+
+ llvm.return
+}
+
+// The test below checks that the nvvm.wmma.*.store.* op is mapped to the correct
+// intrinsic in the LLVM NVPTX backend.
+llvm.func @gpu_wmma_store_op(%arg0: !llvm.ptr<i32, 3>, %arg1: vector<2 x f16>,
+ %arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
+ %arg4: vector<2 xf16>, %arg5: i32) {
+ // CHECK: call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p3i32(i32 addrspace(3)* %{{.*}}, <2 x half> {{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, i32 %{{.*}})
+ nvvm.wmma.m16n16k16.store.d.f16.row.stride %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : !llvm.ptr<i32, 3>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, i32
+ llvm.return
+}
+
+// The test below checks that the nvvm.wmma.*.mma.* op is mapped to the correct
+// intrinsic in the LLVM NVPTX backend.
+llvm.func @gpu_wmma_mma_op(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>,
+ %arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
+ %arg4: vector<2 x f16>, %arg5: vector<2 x f16>,
+ %arg6: vector<2 x f16>, %arg7: vector<2 x f16>,
+ %arg8: vector<2 x f16>, %arg9: vector<2 x f16>,
+ %arg10: vector<2 x f16>, %arg11: vector<2 x f16>,
+ %arg12: vector<2 x f16>, %arg13: vector<2 x f16>,
+ %arg14: vector<2 x f16>, %arg15: vector<2 x f16>,
+ %arg16: vector<2 x f16>, %arg17: vector<2 x f16>,
+ %arg18: vector<2 x f16>, %arg19: vector<2 x f16>) {
+ // CHECK: call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}})
+ %0 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19 : vector<2 x f16> -> !llvm.struct<(vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)>
+
+ llvm.return
+}
+
// This function has the "kernel" attribute attached and should appear in the
// NVVM annotations after conversion.
llvm.func @kernel_func() attributes {nvvm.kernel} {