[Mlir-commits] [mlir] 875eb52 - [MLIR][GPU][NVVM] Add warp synchronous matrix-multiply accumulate ops

Uday Bondhugula llvmlistbot at llvm.org
Wed May 5 23:37:45 PDT 2021


Author: Navdeep Kumar
Date: 2021-05-06T12:06:25+05:30
New Revision: 875eb523c13249114507cb8facd797773e278d9e

URL: https://github.com/llvm/llvm-project/commit/875eb523c13249114507cb8facd797773e278d9e
DIFF: https://github.com/llvm/llvm-project/commit/875eb523c13249114507cb8facd797773e278d9e.diff

LOG: [MLIR][GPU][NVVM] Add warp synchronous matrix-multiply accumulate ops

Add warp synchronous matrix-multiply accumulate ops in the GPU and NVVM
dialects. Add the following three ops to the GPU dialect:
  1.) subgroup_mma_load_matrix
  2.) subgroup_mma_store_matrix
  3.) subgroup_mma_compute
Add the following three ops to the NVVM dialect:
  1.) wmma.m16n16k16.load.[a,b,c].[f16,f32].row.stride
  2.) wmma.m16n16k16.store.d.[f16,f32].row.stride
  3.) wmma.m16n16k16.mma.row.row.[f16,f32].[f16,f32]

Reviewed By: bondhugula, ftynse, ThomasRaoux

Differential Revision: https://reviews.llvm.org/D95330

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/GPUBase.td
    mlir/include/mlir/Dialect/GPU/GPUDialect.h
    mlir/include/mlir/Dialect/GPU/GPUOps.td
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/test/Dialect/GPU/invalid.mlir
    mlir/test/Dialect/GPU/ops.mlir
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/test/Target/LLVMIR/nvvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/GPUBase.td b/mlir/include/mlir/Dialect/GPU/GPUBase.td
index 63cc7df9ed0d5..3945c3b8c47e3 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUBase.td
@@ -57,6 +57,17 @@ def GPU_AsyncToken : DialectType<
   GPU_Dialect, CPred<"$_self.isa<::mlir::gpu::AsyncTokenType>()">, "async token type">,
              BuildableType<"mlir::gpu::AsyncTokenType::get($_builder.getContext())">;
 
+// Predicate to check if a type is gpu::MMAMatrixType.
+def IsMMAMatrixTypePred : CPred<"$_self.isa<::mlir::gpu::MMAMatrixType>()">;
+
+def GPU_MMAMatrix : DialectType<
+  GPU_Dialect, IsMMAMatrixTypePred, "MMAMatrix type">;
+
+class MMAMatrixOf<list<Type> allowedTypes> :
+  ContainerType<AnyTypeOf<allowedTypes>, IsMMAMatrixTypePred,
+  "$_self.cast<::mlir::gpu::MMAMatrixType>().getElementType()",
+  "gpu.mma_matrix", "::mlir::gpu::MMAMatrixType">;
+
 def GPU_AsyncOpInterface : OpInterface<"AsyncOpInterface"> {
   let description = [{
     Interface for GPU operations that execute asynchronously on the device.
@@ -102,4 +113,18 @@ def GPU_AsyncOpInterface : OpInterface<"AsyncOpInterface"> {
   ];
 }
 
+// Cases of the string enum attribute SubgroupMmaOpLayout, representing the
+// row-major and column-major layouts of the operands of the ops using it.
+def RowMajor: StrEnumAttrCase<"RowMajor", 0>;
+def ColMajor: StrEnumAttrCase<"ColMajor", 1>;
+
+// String enum attribute for warp-wide matrix operations, describing the
+// layout of the respective operands. The layout is intended to select the
+// appropriate intrinsics during lowering.
+def SubgroupMmaOpLayout: StrEnumAttr<"Layout", "Specifies whether op is row/col major",
+                           [RowMajor, ColMajor]> {
+  let stringToSymbolFnName = "LayoutStrToEnum";
+  let symbolToStringFnName = "EnumToLayoutStr";
+}
+
 #endif // GPU_BASE

diff  --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
index 26ab171727146..7832d48945656 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
@@ -44,6 +44,122 @@ class AsyncTokenType
   using Base::Base;
 };
 
+/// MMAMatrixType storage and uniquing. The type is uniqued based on its
+/// shape, element type and operand.
+struct MMAMatrixStorageType : public TypeStorage {
+  MMAMatrixStorageType(unsigned numDims, const int64_t *dimShapes,
+                       Type elementType, StringRef operand)
+      : dimShapes(dimShapes), numDims(numDims), elementType(elementType),
+        operand(operand) {}
+
+  /// The uniquing key: (shape, element type, operand).
+  using KeyTy = std::tuple<ArrayRef<int64_t>, Type, StringRef>;
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(getShape(), elementType, operand);
+  }
+
+  /// Construction: copies the shape and operand string into the allocator.
+  static MMAMatrixStorageType *construct(TypeStorageAllocator &allocator,
+                                         const KeyTy &key) {
+    ArrayRef<int64_t> shape = allocator.copyInto(std::get<0>(key));
+    StringRef operand = allocator.copyInto(std::get<2>(key));
+
+    return new (allocator.allocate<MMAMatrixStorageType>())
+        MMAMatrixStorageType(shape.size(), shape.data(), std::get<1>(key),
+                             operand);
+  }
+
+  ArrayRef<int64_t> getShape() const {
+    return ArrayRef<int64_t>(dimShapes, numDims);
+  }
+
+  StringRef getOperand() const { return operand; }
+
+  /// Pointer to the allocator-owned shape array of the MMA matrix.
+  const int64_t *dimShapes;
+
+  /// Number of dimensions in the MMA matrix.
+  unsigned numDims;
+
+  /// Element type of elements held in the MMA matrix.
+  Type elementType;
+
+  /// MMA operand that this MMAMatrix holds. The general form of operation this
+  /// type supports is given by the equation D = (alpha*(A*B)) + (beta*C). This
+  /// field specifies which operand in the given equation is held by this type.
+  /// The valid values are "AOp", "BOp", "COp" and "DOp".
+  StringRef operand;
+};
+
+/// MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply
+/// accumulate operations. MMAMatrices are taken as direct operands by these
+/// operations and are also produced as results. These matrices are meant to
+/// reside in the registers. A limited number of pointwise operations can be
+/// performed on these matrices, i.e., operations which operate uniformly on
+/// all the elements in the matrix and do not change the order of matrix
+/// elements. The above conditions exist because the layout of matrix elements
+/// inside the matrix is opaque, i.e., the elements may be present in the
+/// matrix in any order. The general usage of this type is shown as follows:
+///
+///   %0 = gpu.subgroup_mma_load_matrix %arg0[%c0, %c0] {leadDimension = 16 :
+///           index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+///
+/// The MMAMatrixType describes the shape of the matrix being loaded and the
+/// operand being loaded too. The operand needs to be specified to aid the
+/// lowering of this type to dialects such as NVVM, where each workitem may
+/// hold a different number of elements depending on the operand and element
+/// type. For example, with the m16n16k16 NVVM WMMA ops a workitem holds 8
+/// vector<2xf16>s of an f16 "AOp"/"BOp" matrix, 4 vector<2xf16>s of an f16
+/// "COp" matrix and 8 f32s of an f32 "COp" matrix. Other usage examples:
+///
+///   %3 = gpu.subgroup_mma_compute %0, %1, %2 : !gpu.mma_matrix<16x16xf16,
+///   "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf32,
+///                             "COp"> -> !gpu.mma_matrix<16x16xf32, "DOp">
+///
+///
+///   gpu.subgroup_mma_store_matrix %3, %arg22[%c0, %c0] {leadDimension = 16
+///           : index}: !gpu.mma_matrix<16x16xf32, "DOp">, memref<16x16xf32>
+// TODO: consider moving this to ODS.
+class MMAMatrixType
+    : public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType> {
+public:
+  using Base::Base;
+
+  /// Get MMAMatrixType and verify construction invariants.
+  static MMAMatrixType get(ArrayRef<int64_t> shape, Type elementType,
+                           StringRef operand);
+
+  /// Get MMAMatrixType at a particular location and verify construction
+  /// invariants.
+  static MMAMatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
+                                  ArrayRef<int64_t> shape, Type elementType,
+                                  StringRef operand);
+
+  /// Check if a type is a valid MMAMatrixType element type.
+  static bool isValidElementType(Type elementType);
+
+  /// Verify that shape and elementType are actually allowed for the
+  /// MMAMatrixType.
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              ArrayRef<int64_t> shape, Type elementType,
+                              StringRef operand);
+
+  /// Get number of dims.
+  unsigned getNumDims() const;
+
+  /// Get shape of the matrix.
+  ArrayRef<int64_t> getShape() const;
+
+  /// Get elementType of a single element.
+  Type getElementType() const;
+
+  /// The general form of operation this type supports is given by the equation
+  /// D = (alpha*(A*B)) + (beta*C). This function returns which operand in the
+  /// given equation is held by this type. The returned string is one of "AOp",
+  /// "BOp", "COp" and "DOp".
+  StringRef getOperand() const;
+};
+
 // Adds a `gpu.async.token` to the front of the argument list.
 void addAsyncDependency(Operation *op, Value token);
 

diff  --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index 41206af46ae2d..5bd6956e75c0c 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -912,4 +912,122 @@ def GPU_MemcpyOp : GPU_Op<"memcpy", [GPU_AsyncOpInterface]> {
   let verifier = [{ return ::verify(*this); }];
 }
 
+def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
+    [MemoryEffects<[MemRead]>]>{
+
+  let summary = "GPU warp synchronous matrix load";
+
+  let description = [{
+    The `gpu.subgroup_mma_load_matrix` operation loads a matrix collectively
+    using all the threads in a subgroup.
+
+    This operation takes a memref as argument. It is the source matrix from which
+    data is to be loaded. The op returns a `!gpu.mma_matrix`. The source memref
+    can be in the generic, global or shared memory space. The start of the load
+    address is determined using the indices provided. The matrix being loaded is
+    specified by the result type. This is necessary because there exists a
+    different LLVM intrinsic for loading each operand, probably because each
+    operand needs to be laid out in a specific/different way in the registers.
+    The `leadDimension` attribute specifies the leading dimension of the source matrix.
+
+    This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and
+    `gpu.subgroup_mma_compute`.
+
+    Example:
+
+    ```mlir
+    %0 = gpu.subgroup_mma_load_matrix %src[%i, %j] {leadDimension = 32 : index}
+        : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    ```
+  }];
+
+  let arguments = (ins Arg<MemRefRankOf<[F16, F32], [2]>, "", [MemRead]>:$srcMemref,
+                  Variadic<Index>:$indices,
+                  IndexAttr:$leadDimension);
+
+  let results = (outs GPU_MMAMatrix:$res);
+
+  let assemblyFormat = [{
+    $srcMemref`[`$indices`]` attr-dict `:` type($srcMemref) `->` type($res)
+  }];
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
+    [MemoryEffects<[MemWrite]>]>{
+
+  let summary = "GPU warp synchronous matrix store";
+
+  let description = [{
+    The `gpu.subgroup_mma_store_matrix` operation stores a matrix collectively
+    using all the threads in a subgroup.
+
+    This operation takes a `!gpu.mma_matrix` and a memref as arguments.
+    `!gpu.mma_matrix` is the source which contains the data to be stored.
+    The destination can be in the generic, global or shared memory space. The
+    start of the store address is determined using the indices provided. The
+    `leadDimension` attribute specifies the leading dimension of the destination matrix.
+
+    This op is meant to be used along with `gpu.subgroup_mma_load_matrix` and
+    `gpu.subgroup_mma_compute`.
+
+    Example:
+
+    ```mlir
+    gpu.subgroup_mma_store_matrix %D, %sg[%i, %j] {leadDimension = 32 : index}
+        : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 3>
+    ```
+  }];
+
+  let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$src,
+                  Arg<MemRefRankOf<[F16, F32], [2]>, "",[MemWrite]>:$dstMemref,
+                  Variadic<Index>:$indices,
+                  IndexAttr:$leadDimension);
+
+  let assemblyFormat = [{
+    $src`,` $dstMemref`[`$indices`]` attr-dict `:` type($src)`,` type($dstMemref)
+  }];
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute", []>{
+
+  let summary = "GPU warp synchronous matrix multiply accumulate";
+
+  let description = [{
+    The `gpu.subgroup_mma_compute` operation performs a matrix-multiply accumulate (mma)
+    operation using all the threads in a subgroup.
+
+    This operation takes three `!gpu.mma_matrix`s as arguments. These hold the `A`,
+    `B` and `C` operands for the mma operation. The operation performed is represented
+    as `D = A * B + C`. The op returns a `!gpu.mma_matrix` which contains the result of
+    the operation held by the current thread.
+
+    This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and
+    `gpu.subgroup_mma_load_matrix`.
+
+    Example:
+
+    ```mlir
+    %D = gpu.subgroup_mma_compute %A, %B, %C :
+    !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">,
+    !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
+    ```
+  }];
+
+  let arguments = (ins Arg<MMAMatrixOf<[F16]>>:$opA,
+                  Arg<MMAMatrixOf<[F16]>>:$opB,
+                  Arg<MMAMatrixOf<[F16, F32]>>:$opC);
+
+  let results = (outs GPU_MMAMatrix:$res);
+
+  let assemblyFormat = [{
+    $opA`,` $opB`,` $opC attr-dict `:` type($opA)`,` type($opB)`,` type($opC) `->` type($res)
+  }];
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
 #endif // GPU_OPS

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 203a0b2031c9f..a3f5a84dad59f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -151,4 +151,254 @@ def NVVM_MmaOp :
   let verifier = [{ return ::verify(*this); }];
 }
 
+// Base class for all the variants of WMMA loadOps that may be defined.
+class NVVM_WMMALoadOp<string mnemonic> : NVVM_Op<mnemonic>,
+  Results<(outs LLVM_AnyStruct:$res)>,
+  Arguments<(ins Variadic<LLVM_Type>:$args)> {
+
+  let summary = "Warp synchronous matrix load";
+
+  string baseDescription = [{
+    The `nvvm.wmma.m*n*k*.load.[a, b, c]` operation loads a matrix collectively
+    using all the threads in a warp.
+
+    The operation takes two arguments: the address from which the matrix
+    elements are loaded and a stride representing the leading dimension of the
+    source matrix. Both are required to be the same across all threads in the
+    warp. Each thread holds a certain number of elements; the op returns an LLVM
+    struct which holds the elements of the matrix held by this thread.
+
+    This op is meant to be used along with `nvvm.wmma.m*n*k*.store` and
+    `nvvm.wmma.m*n*k*.mma`.}];
+
+  let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
+}
+
+def NVVM_WMMALoadAM16N16K16Op :
+  NVVM_WMMALoadOp<"wmma.m16n16k16.load.a.f16.row.stride">{
+
+  string llvmBuilder = [{
+    $res = createNvvmIntrinsicCall(
+      builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride, $args);
+  }];
+
+  string opDescription = [{
+    Example:
+
+    ```mlir
+    %2 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %0, %1 : (!llvm.ptr<i32, 3>, i32) ->
+    !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
+    vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    ```
+  }];
+
+  let description = !strconcat(baseDescription, opDescription);
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def NVVM_WMMALoadBM16N16K16Op :
+  NVVM_WMMALoadOp<"wmma.m16n16k16.load.b.f16.row.stride">{
+
+  string llvmBuilder = [{
+    $res = createNvvmIntrinsicCall(
+      builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride, $args);
+  }];
+
+  string opDescription = [{
+    Example:
+
+    ```mlir
+    %2 = nvvm.wmma.m16n16k16.load.b.f16.row.stride %0, %1 : (!llvm.ptr<i32, 3>, i32) ->
+    !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
+    vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    ```
+  }];
+
+  let description = !strconcat(baseDescription, opDescription);
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def NVVM_WMMALoadCF16M16N16K16Op :
+  NVVM_WMMALoadOp<"wmma.m16n16k16.load.c.f16.row.stride">{
+  string llvmBuilder = [{
+    $res = createNvvmIntrinsicCall(
+      builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride, $args);
+  }];
+
+  string opDescription = [{
+    Example:
+
+    ```mlir
+    %2 = nvvm.wmma.m16n16k16.load.c.f16.row.stride %0, %1 : (!llvm.ptr<i32, 3>, i32) ->
+    !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    ```
+  }];
+
+  let description = !strconcat(baseDescription, opDescription);
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def NVVM_WMMALoadCF32M16N16K16Op :
+  NVVM_WMMALoadOp<"wmma.m16n16k16.load.c.f32.row.stride">{
+  string llvmBuilder = [{
+    $res = createNvvmIntrinsicCall(
+      builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride, $args);
+  }];
+
+  string opDescription = [{
+    Example:
+
+    ```mlir
+    %2 = nvvm.wmma.m16n16k16.load.c.f32.row.stride %0, %1 : (!llvm.ptr<i32, 3>, i32) ->
+    !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+    ```
+  }];
+
+  let description = !strconcat(baseDescription, opDescription);
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
+// Base class for all the variants of WMMA storeOps that may be defined.
+class NVVM_WMMAStoreOp<string mnemonic> : NVVM_Op<mnemonic>,
+  Arguments<(ins Variadic<LLVM_Type>:$args)>{
+  let summary = "Warp synchronous matrix store";
+
+  string baseDescription = [{
+    The `nvvm.wmma.m*n*k*.store` operation stores a matrix collectively using
+    all the threads in a warp.
+
+    The operation takes as arguments the address where the matrix elements are to
+    be stored, the elements to store (held by the current thread) and a stride.
+    The stride argument represents the leading dimension of the destination matrix.
+    The address and the stride are required to be the same across all threads in the
+    warp.
+
+    This op is meant to be used along with `nvvm.wmma.m*n*k*.load` and
+    `nvvm.wmma.m*n*k*.mma`.
+  }];
+
+  let assemblyFormat = "$args attr-dict `:` type($args)";
+}
+
+def NVVM_WMMAStoreF16M16N16K16Op : NVVM_WMMAStoreOp<"wmma.m16n16k16.store.d.f16.row.stride"> {
+  string llvmBuilder = [{
+        createNvvmIntrinsicCall(
+        builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride, $args);
+  }];
+
+  string opDescription = [{
+    Example:
+
+    ```mlir
+    nvvm.wmma.m16n16k16.store.d.f16.row.stride %0, %1, %2, %3, %4, %5 : !llvm.ptr<i32, 3>,
+    vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32
+    ```
+  }];
+
+  let description = !strconcat(baseDescription, opDescription);
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def NVVM_WMMAStoreF32M16N16K16Op : NVVM_WMMAStoreOp<"wmma.m16n16k16.store.d.f32.row.stride"> {
+  string llvmBuilder = [{
+        createNvvmIntrinsicCall(
+        builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride, $args);
+  }];
+
+  string opDescription = [{
+    Example:
+
+    ```mlir
+    nvvm.wmma.m16n16k16.store.d.f32.row.stride %0, %1, %2, %3, %4, %5, %6, %7, %8,
+    %9 : !llvm.ptr<i32, 3>, f32, f32, f32, f32, f32, f32, f32, f32,
+    i32
+    ```
+  }];
+
+  let description = !strconcat(baseDescription, opDescription);
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
+// Base class for all the variants of WMMA mmaOps that may be defined.
+class NVVM_WMMAMmaOp<string mnemonic> : NVVM_Op<mnemonic>,
+  Results<(outs LLVM_AnyStruct:$res)>,
+  Arguments<(ins Variadic<LLVM_Type>:$args)>{
+  let summary = "Warp synchronous matrix-multiply accumulate using tensor cores.";
+
+  string baseDescription = [{
+    The `nvvm.wmma.m*n*k*.mma` operation performs a matrix-multiply accumulate
+    (mma) operation using all the threads in a warp.
+
+    The operation performed is represented as `D = A * B + C`. The operation takes
+    as arguments the elements of the matrices `A`, `B` and `C` held by the current
+    thread. The op returns an LLVM struct which holds the part of the result `D`
+    held by the current thread.
+
+    This op is meant to be used along with `nvvm.wmma.m*n*k*.load` and
+    `nvvm.wmma.m*n*k*.store`.
+  }];
+}
+
+def NVVM_WMMAMmaF16F16M16N16K16Op : NVVM_WMMAMmaOp<"wmma.m16n16k16.mma.row.row.f16.f16">{
+  string llvmBuilder = [{
+    $res = createNvvmIntrinsicCall(
+        builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f16_f16, $args);
+  }];
+
+  string opDescription = [{
+    Example:
+
+    ```mlir
+    %20 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %0, %1, %2, %3, %4, %5, %6, %7, %8,
+    %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19 : vector<2xf16> ->
+    !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    ```
+  }];
+
+  let parser = [{
+    return parseWMMAMmaF16F16M16N16K16Op(parser, result);
+  }];
+
+  let printer = [{
+    printWMMAMmaF16F16M16N16K16Op(p, *this);
+  }];
+
+  let description = !strconcat(baseDescription, opDescription);
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def NVVM_WMMAMmaF32F32M16N16K16Op : NVVM_WMMAMmaOp<"wmma.m16n16k16.mma.row.row.f32.f32">{
+  string llvmBuilder = [{
+    $res = createNvvmIntrinsicCall(
+        builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f32_f32, $args);
+  }];
+
+  string opDescription = [{
+    Example:
+
+    ```mlir
+    %24 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %0, %1, %2, %3, %4, %5, %6, %7, %8,
+    %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23 :
+    (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
+    vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
+    vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
+    vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32,
+    f32, f32, f32, f32, f32, f32, f32)>
+    ```
+  }];
+
+  let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
+
+  let description = !strconcat(baseDescription, opDescription);
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
 #endif // NVVMIR_OPS

diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index c6ceecd2b3ddd..137961443af33 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -257,6 +257,14 @@ llvm::Value *createIntrinsicCall(llvm::IRBuilderBase &builder,
                                  llvm::Intrinsic::ID intrinsic,
                                  ArrayRef<llvm::Value *> args = {},
                                  ArrayRef<llvm::Type *> tys = {});
+
+/// Creates a call to the LLVM IR intrinsic `intrinsic` with the given `args`,
+/// for use by the NVVM WMMA ops. Intended for intrinsics overloaded on their
+/// argument types: the concrete intrinsic is selected by inspecting the types
+/// of `args`, so unlike createIntrinsicCall no overload types are passed.
+llvm::Value *createNvvmIntrinsicCall(llvm::IRBuilderBase &builder,
+                                     llvm::Intrinsic::ID intrinsic,
+                                     ArrayRef<llvm::Value *> args = {});
 } // namespace detail
 
 } // namespace LLVM

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 27fde9b87405f..1fa687f83f0d9 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -28,10 +28,70 @@
 using namespace mlir;
 using namespace mlir::gpu;
 
+//===----------------------------------------------------------------------===//
+// MMAMatrixType
+//===----------------------------------------------------------------------===//
+
+MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
+                                 StringRef operand) {
+  return Base::get(elementType.getContext(), shape, elementType, operand);
+}
+
+MMAMatrixType
+MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                          ArrayRef<int64_t> shape, Type elementType,
+                          StringRef operand) {
+  return Base::getChecked(emitError, elementType.getContext(), shape,
+                          elementType, operand);
+}
+
+unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }
+
+ArrayRef<int64_t> MMAMatrixType::getShape() const {
+  return getImpl()->getShape();
+}
+
+Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
+
+StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
+
+bool MMAMatrixType::isValidElementType(Type elementType) {
+  return elementType.isF16() || elementType.isF32();
+}
+
+LogicalResult
+MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
+                      ArrayRef<int64_t> shape, Type elementType,
+                      StringRef operand) {
+  if (!operand.equals("AOp") && !operand.equals("BOp") &&
+      !operand.equals("COp") && !operand.equals("DOp"))
+    return emitError() << "operand expected to be one of AOp, BOp, COp or DOp";
+  // Only rank-2 shapes (matrices) are supported.
+  if (shape.size() != 2)
+    return emitError() << "MMAMatrixType must have exactly two dimensions";
+  // Element types are restricted to those accepted by isValidElementType.
+  if (!MMAMatrixType::isValidElementType(elementType))
+    return emitError() << "MMAMatrixType elements must be F16 or F32";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // GPUDialect
 //===----------------------------------------------------------------------===//
 
+/// GPU memory space identifiers accepted by the subgroup MMA load/store ops.
+enum GPUMemorySpace {
+  /// Generic memory space identifier.
+  kGenericMemorySpace = 0,
+
+  /// Global memory space identifier.
+  kGlobalMemorySpace = 1,
+
+  /// Shared memory space identifier.
+  kSharedMemorySpace = 3
+};
+
 bool GPUDialect::isKernel(Operation *op) {
   UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
   return static_cast<bool>(isKernelAttr);
@@ -39,6 +99,7 @@ bool GPUDialect::isKernel(Operation *op) {
 
 void GPUDialect::initialize() {
   addTypes<AsyncTokenType>();
+  addTypes<MMAMatrixType>();
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/GPU/GPUOps.cpp.inc"
@@ -56,6 +117,38 @@ Type GPUDialect::parseType(DialectAsmParser &parser) const {
   if (keyword == "async.token")
     return AsyncTokenType::get(context);
 
+  if (keyword == "mma_matrix") {
+    // Syntax: mma_matrix<MxNxelementType, "operand">, e.g.
+    // mma_matrix<16x16xf16, "AOp">. Invariants are checked by getChecked.
+    llvm::SMLoc beginLoc = parser.getNameLoc();
+
+    // Parse '<', the static shape, the element type and the ',' separator.
+    SmallVector<int64_t> shape;
+    Type elementType;
+    if (parser.parseLess() ||
+        parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
+        parser.parseType(elementType) ||
+        parser.parseComma())
+      return nullptr;
+
+    // Parse the quoted operand name. parseOptionalString does not report an
+    // error on failure, so emit one here rather than failing silently.
+    StringRef operand;
+    llvm::SMLoc operandLoc = parser.getCurrentLocation();
+    if (failed(parser.parseOptionalString(&operand))) {
+      parser.emitError(operandLoc, "expected operand string");
+      return nullptr;
+    }
+
+    // Parse '>'.
+    if (parser.parseGreater())
+      return nullptr;
+
+    return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
+                                         parser.getEncodedSourceLoc(beginLoc)),
+                                     shape, elementType, operand);
+  }
+
   parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
   return Type();
 }
@@ -63,6 +156,14 @@ Type GPUDialect::parseType(DialectAsmParser &parser) const {
 void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
   TypeSwitch<Type>(type)
       .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
+      .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
+        // Prints e.g. mma_matrix<16x16xf16, "AOp">. Every dimension is followed
+        // by 'x', so no shape.back()/end()-1 access is needed (safe if empty).
+        os << "mma_matrix<";
+        for (int64_t dim : fragTy.getShape())
+          os << dim << 'x';
+        os << fragTy.getElementType() << ", \"" << fragTy.getOperand() << "\">";
+      })
       .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
 }
 
@@ -138,7 +239,8 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
   return walkResult.wasInterrupted() ? failure() : success();
 }
 
-template <typename T> static LogicalResult verifyIndexOp(T op) {
+template <typename T>
+static LogicalResult verifyIndexOp(T op) {
   auto dimension = op.dimension();
   if (dimension != "x" && dimension != "y" && dimension != "z")
     return op.emitError("dimension \"") << dimension << "\" is invalid";
@@ -885,6 +987,95 @@ static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
   printer << "]";
 }
 
+//===----------------------------------------------------------------------===//
+// GPU_SubgroupMmaLoadMatrixOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies the source memref (identity layout, supported memory space, one
+/// index per dimension) and that the result is a loadable AOp/BOp/COp matrix.
+static LogicalResult verify(SubgroupMmaLoadMatrixOp op) {
+  auto srcMemrefType = op.srcMemref().getType().cast<MemRefType>();
+  auto srcMemSpace = srcMemrefType.getMemorySpaceAsInt();
+  auto operand = op.res().getType().cast<gpu::MMAMatrixType>().getOperand();
+
+  if (!srcMemrefType.getAffineMaps().empty() &&
+      !srcMemrefType.getAffineMaps().front().isIdentity())
+    return op.emitError("expected identity layout map for source memref");
+
+  if (srcMemSpace != kGenericMemorySpace && srcMemSpace != kSharedMemorySpace &&
+      srcMemSpace != kGlobalMemorySpace)
+    return op.emitError(
+        "source memorySpace kGenericMemorySpace, kSharedMemorySpace or "
+        "kGlobalMemorySpace only allowed");
+
+  if (!operand.equals("AOp") && !operand.equals("BOp") &&
+      !operand.equals("COp"))
+    return op.emitError("only AOp, BOp and COp can be loaded");
+  if (op.indices().size() != static_cast<size_t>(srcMemrefType.getRank()))
+    return op.emitError("expected one index per source memref dimension");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// GPU_SubgroupMmaStoreMatrixOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies the destination memref and that a DOp matrix is being stored.
+static LogicalResult verify(SubgroupMmaStoreMatrixOp op) {
+  auto srcMatrixType = op.src().getType().cast<gpu::MMAMatrixType>();
+  auto dstMemrefType = op.dstMemref().getType().cast<MemRefType>();
+  auto dstMemSpace = dstMemrefType.getMemorySpaceAsInt();
+
+  if (!dstMemrefType.getAffineMaps().empty() &&
+      !dstMemrefType.getAffineMaps().front().isIdentity())
+    return op.emitError("expected identity layout map for destination memref");
+
+  if (dstMemSpace != kGenericMemorySpace && dstMemSpace != kSharedMemorySpace &&
+      dstMemSpace != kGlobalMemorySpace)
+    return op.emitError(
+        "destination memorySpace of kGenericMemorySpace, "
+        "kGlobalMemorySpace or kSharedMemorySpace only allowed");
+
+  if (!srcMatrixType.getOperand().equals("DOp"))
+    return op.emitError(
+        "expected the operand matrix being stored to have 'DOp' operand type");
+  if (op.indices().size() != static_cast<size_t>(dstMemrefType.getRank()))
+    return op.emitError("expected one index per destination memref dimension");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// GPU_SubgroupMmaComputeOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the operands are AOp, BOp and COp matrices with shapes
+/// MxK, KxN and MxN, and that the result is an MxN DOp matrix.
+static LogicalResult verify(SubgroupMmaComputeOp op) {
+  auto aType = op.opA().getType().cast<MMAMatrixType>();
+  auto bType = op.opB().getType().cast<MMAMatrixType>();
+  auto cType = op.opC().getType().cast<MMAMatrixType>();
+  auto resType = op.res().getType().cast<MMAMatrixType>();
+
+  if (!aType.getOperand().equals("AOp") ||
+      !bType.getOperand().equals("BOp") ||
+      !cType.getOperand().equals("COp"))
+    return op.emitError("operands must be in the order AOp, BOp, COp");
+
+  ArrayRef<int64_t> aShape = aType.getShape();
+  ArrayRef<int64_t> bShape = bType.getShape();
+  ArrayRef<int64_t> cShape = cType.getShape();
+  if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
+      bShape[1] != cShape[1])
+    return op.emitError("operand shapes do not satisfy matmul constraints");
+
+  // D = A * B + C, so the result must be a DOp matrix shaped like C.
+  if (!resType.getOperand().equals("DOp") || resType.getShape() != cShape)
+    return op.emitError(
+        "result must be a DOp matrix with the same shape as the COp operand");
+
+  return success();
+}
+
 #include "mlir/Dialect/GPU/GPUOpInterfaces.cpp.inc"
 
 #define GET_OP_CLASSES

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 3b6d2395d0ca5..c16e1c2f7af4f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -94,12 +94,12 @@ static LogicalResult verify(MmaOp op) {
   auto f32x8StructTy = LLVM::LLVMStructType::getLiteral(
       context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
 
-  SmallVector<Type, 12> operand_types(op.getOperandTypes().begin(),
-                                      op.getOperandTypes().end());
-  if (operand_types != SmallVector<Type, 8>(8, f16x2Ty) &&
-      operand_types != SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
-                                             f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
-                                             f32Ty, f32Ty, f32Ty}) {
+  SmallVector<Type, 12> operandTypes(op.getOperandTypes().begin(),
+                                     op.getOperandTypes().end());
+  if (operandTypes != SmallVector<Type, 8>(8, f16x2Ty) &&
+      operandTypes != SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
+                                            f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
+                                            f32Ty, f32Ty, f32Ty}) {
     return op.emitOpError(
         "expected operands to be 4 <halfx2>s followed by either "
         "4 <halfx2>s or 8 floats");
@@ -120,9 +120,9 @@ static LogicalResult verify(MmaOp op) {
         "\"row\" or \"col\"");
   }
 
-  if (operand_types == SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
-                                             f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
-                                             f32Ty, f32Ty, f32Ty} &&
+  if (operandTypes == SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
+                                            f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
+                                            f32Ty, f32Ty, f32Ty} &&
       op.getType() == f32x8StructTy && alayout.getValue() == "row" &&
       blayout.getValue() == "col") {
     return success();
@@ -130,6 +130,205 @@ static LogicalResult verify(MmaOp op) {
   return op.emitOpError("unimplemented mma.sync variant");
 }
 
+/// Shared verifier for the m16n16k16 WMMA fragment loads.
+///
+/// The operands must be an i32 pointer in address space 1, 3 or 0, followed by
+/// the i32 leading dimension (ldm). The result is then checked against the
+/// fragment kind named by `operand`:
+///   - "AOp"/"BOp" loads must produce a struct of 8 x <2 x f16>.
+///   - "COp" loads must produce a struct of either 4 x <2 x f16> or 8 x f32.
+/// Any other `operand` string skips the result check.
+template <typename T>
+static LogicalResult verifyWMMALoadOp(T op, StringRef operand) {
+  MLIRContext *ctx = op.getContext();
+  Type i32 = IntegerType::get(ctx, 32);
+  Type f32 = FloatType::getF32(ctx);
+  Type halfPair = VectorType::get(2, FloatType::getF16(ctx));
+  // Builds the literal struct made of `count` copies of `elem`.
+  auto structOf = [ctx](unsigned count, Type elem) -> Type {
+    return LLVM::LLVMStructType::getLiteral(ctx,
+                                            SmallVector<Type>(count, elem));
+  };
+
+  SmallVector<Type, 2> operandTypes(op.getOperandTypes().begin(),
+                                    op.getOperandTypes().end());
+  // True when the operands are exactly (ptr<i32, addrSpace>, i32).
+  auto isPtrLdmPair = [&](unsigned addrSpace) {
+    return operandTypes ==
+           SmallVector<Type, 2>{LLVM::LLVMPointerType::get(i32, addrSpace),
+                                i32};
+  };
+  if (!isPtrLdmPair(1) && !isPtrLdmPair(3) && !isPtrLdmPair(0))
+    return op.emitOpError("expected operands to be a source pointer in memory "
+                          "space 0, 1, 3 followed by ldm of the source");
+
+  Type resultType = op.getType();
+  bool isABLoad = operand.equals("AOp") || operand.equals("BOp");
+  if (isABLoad && resultType != structOf(8, halfPair))
+    return op.emitOpError("expected result type of loadAOp and loadBOp to be "
+                          "a struct of 8 <halfx2>s");
+
+  bool isCLoad = !isABLoad && operand.equals("COp");
+  if (isCLoad && resultType != structOf(4, halfPair) &&
+      resultType != structOf(8, f32))
+    return op.emitOpError("expected result type of loadCOp to be a struct of "
+                          "4 <halfx2>s or 8 f32s");
+
+  return success();
+}
+
+/// Verifies the A-fragment load: pointer + ldm operands, 8 x <2 x f16> result.
+static LogicalResult verify(WMMALoadAM16N16K16Op op) {
+  return verifyWMMALoadOp(op, "AOp");
+}
+
+/// Verifies the B-fragment load: pointer + ldm operands, 8 x <2 x f16> result.
+static LogicalResult verify(WMMALoadBM16N16K16Op op) {
+  return verifyWMMALoadOp(op, "BOp");
+}
+
+/// Verifies the f16 accumulator load. The shared "COp" check accepts either
+/// accumulator layout (4 x <2 x f16> or 8 x f32). Without the extra check
+/// below, an f16 C-load declared to return 8 f32s would verify even though
+/// its result disagrees with the op's f16 element type.
+static LogicalResult verify(WMMALoadCF16M16N16K16Op op) {
+  if (failed(verifyWMMALoadOp(op, "COp")))
+    return failure();
+  MLIRContext *context = op.getContext();
+  auto f16x2Ty = VectorType::get(2, FloatType::getF16(context));
+  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
+      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
+  if (op.getType() != f16x2x4StructTy)
+    return op.emitOpError("expected result type of loadCOp with f16 elements "
+                          "to be a struct of 4 <halfx2>s");
+  return success();
+}
+
+/// Verifies the f32 accumulator load. Like the f16 variant, this pins the
+/// layout that matches the op's element type (8 x f32) on top of the shared
+/// "COp" check, which would also accept 4 x <2 x f16>.
+static LogicalResult verify(WMMALoadCF32M16N16K16Op op) {
+  if (failed(verifyWMMALoadOp(op, "COp")))
+    return failure();
+  MLIRContext *context = op.getContext();
+  auto f32Ty = FloatType::getF32(context);
+  auto f32x8StructTy = LLVM::LLVMStructType::getLiteral(
+      context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
+  if (op.getType() != f32x8StructTy)
+    return op.emitOpError("expected result type of loadCOp with f32 elements "
+                          "to be a struct of 8 f32s");
+  return success();
+}
+
+/// Returns true iff the operand types of `op` are exactly `containedElems`:
+/// same count, same types, same order. The expected list is only read, so it
+/// is taken as an ArrayRef. Existing SmallVector arguments still convert
+/// implicitly.
+template <typename T>
+static bool verifyWMMAStoreOp(T op, ArrayRef<Type> containedElems) {
+  SmallVector<Type> operandTypes(op.getOperandTypes().begin(),
+                                 op.getOperandTypes().end());
+  return containedElems.equals(operandTypes);
+}
+
+/// Verifies the f16 fragment store. The operands must be an i32 pointer in
+/// address space 1, 0 or 3, then four <2 x f16> values, then the i32 ldm.
+static LogicalResult verify(WMMAStoreF16M16N16K16Op op) {
+  MLIRContext *ctx = op.getContext();
+  Type i32 = IntegerType::get(ctx, 32);
+  Type halfPair = VectorType::get(2, FloatType::getF16(ctx));
+
+  for (unsigned addrSpace : {1u, 0u, 3u}) {
+    SmallVector<Type> expected{LLVM::LLVMPointerType::get(i32, addrSpace),
+                               halfPair, halfPair, halfPair, halfPair, i32};
+    if (verifyWMMAStoreOp(op, expected))
+      return success();
+  }
+
+  // NOTE(review): the two literals concatenate to "...in memoryspace 0, 1,
+  // 3..." (no space) and call the destination a "source". The LLVMIR
+  // invalid.mlir tests match this exact text, so change both together.
+  return op.emitOpError("expected operands to be a source pointer in memory"
+                        "space 0, 1, 3 followed by ldm of the source");
+}
+
+/// Verifies the f32 fragment store. The operands must be an i32 pointer in
+/// address space 0, 1 or 3, then eight f32 values, then the i32 ldm.
+static LogicalResult verify(WMMAStoreF32M16N16K16Op op) {
+  MLIRContext *ctx = op.getContext();
+  Type i32 = IntegerType::get(ctx, 32);
+  Type f32 = FloatType::getF32(ctx);
+
+  // Layout: [pointer, f32 x 8, i32]. Only the pointer slot varies per space.
+  SmallVector<Type> expected(10, f32);
+  expected.back() = i32;
+  for (unsigned addrSpace : {0u, 1u, 3u}) {
+    expected.front() = LLVM::LLVMPointerType::get(i32, addrSpace);
+    if (verifyWMMAStoreOp(op, expected))
+      return success();
+  }
+
+  // NOTE(review): the literals concatenate to "memoryspace" (no space) and
+  // describe the destination as a "source". Kept verbatim to match the f16
+  // store verifier and its tests.
+  return op.emitOpError("expected operands to be a source pointer in memory"
+                        "space 0, 1, 3 followed by ldm of the source");
+}
+
+/// Verifies the f16 x f16 -> f16 MMA. It requires exactly 20 <2 x f16>
+/// operands (presumably 8 A, 8 B and 4 C registers, matching the load result
+/// sizes) and a result struct of 4 x <2 x f16>.
+static LogicalResult verify(WMMAMmaF16F16M16N16K16Op op) {
+  MLIRContext *ctx = op.getContext();
+  Type halfPair = VectorType::get(2, FloatType::getF16(ctx));
+
+  bool allHalfPairs = llvm::all_of(op.getOperandTypes(),
+                                   [&](Type t) { return t == halfPair; });
+  if (op->getNumOperands() != 20 || !allHalfPairs)
+    return op.emitOpError("expected 20 <halfx2>s as operands");
+
+  Type expectedResult = LLVM::LLVMStructType::getLiteral(
+      ctx, {halfPair, halfPair, halfPair, halfPair});
+  if (op.getResult().getType() != expectedResult)
+    return op.emitOpError("expected result type to be a struct of 4 <halfx2>s");
+
+  return success();
+}
+
+/// Parses `operands attr-dict : operand-type -> result-type`. The single
+/// operand type after the colon is applied to every operand.
+static LogicalResult parseWMMAMmaF16F16M16N16K16Op(OpAsmParser &parser,
+                                                   OperationState &result) {
+  SmallVector<OpAsmParser::OperandType, 4> operands;
+  Type operandType, resType;
+  llvm::SMLoc operandsLoc = parser.getCurrentLocation();
+
+  if (parser.parseOperandList(operands) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColon() || parser.parseType(operandType) ||
+      parser.parseArrow() || parser.parseType(resType))
+    return failure();
+
+  result.addTypes(resType);
+  // Every operand shares the one type spelled after the colon.
+  SmallVector<Type> operandTypes(operands.size(), operandType);
+  return parser.resolveOperands(operands, operandTypes, operandsLoc,
+                                result.operands);
+}
+
+/// Prints `operands attr-dict : operand-type -> result-type`. Only the first
+/// operand's type is printed, because all operands are expected to share it.
+static void printWMMAMmaF16F16M16N16K16Op(OpAsmPrinter &p,
+                                          WMMAMmaF16F16M16N16K16Op &op) {
+  p << op.getOperationName() << ' ' << op.args();
+  p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{});
+  p << " : " << op->getOperand(0).getType() << " -> " << op.res().getType();
+}
+
+/// Verifies the f16 x f16 -> f32 MMA. It requires exactly 16 <2 x f16>
+/// operands for the A and B fragments, followed by 8 f32 operands for the C
+/// accumulator, and a result struct of 8 f32s.
+static LogicalResult verify(WMMAMmaF32F32M16N16K16Op op) {
+  const unsigned numABOperands = 16;
+  const unsigned numCOperands = 8;
+  MLIRContext *context = op.getContext();
+  auto f16Ty = FloatType::getF16(context);
+  auto f32Ty = FloatType::getF32(context);
+  auto f16x2Ty = VectorType::get(2, f16Ty);
+  auto f32x8StructTy = LLVM::LLVMStructType::getLiteral(
+      context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
+
+  auto operands = op->getOperands();
+
+  SmallVector<Type> abOpTypes;
+  for (Value operand : operands.take_front(numABOperands))
+    abOpTypes.push_back(operand.getType());
+  if (abOpTypes != SmallVector<Type>(numABOperands, f16x2Ty))
+    return op.emitOpError("expected 16 <halfx2>s for `a` and `b` operand");
+
+  // The C slice is only built after the A/B check has guaranteed at least
+  // `numABOperands` operands, so `drop_front` never runs past the end.
+  SmallVector<Type> cOpTypes;
+  for (Value operand :
+       operands.drop_front(numABOperands).take_front(numCOperands))
+    cOpTypes.push_back(operand.getType());
+  if (cOpTypes != SmallVector<Type>(numCOperands, f32Ty))
+    return op.emitOpError("expected 8 f32s for `c` operand");
+
+  // The checks above only inspect the leading 24 operands. Reject trailing
+  // extras instead of silently ignoring them.
+  if (op->getNumOperands() != numABOperands + numCOperands)
+    return op.emitOpError("expected exactly 24 operands (16 <halfx2>s "
+                          "followed by 8 f32s)");
+
+  if (op.getResult().getType() != f32x8StructTy)
+    return op.emitOpError("expected result type to be a struct of 8 f32s");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//
@@ -141,7 +340,8 @@ void NVVMDialect::initialize() {
 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
       >();
 
-  // Support unknown operations because not all NVVM operations are registered.
+  // Support unknown operations because not all NVVM operations are
+  // registered.
   allowUnknownOperations();
 }
 

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index bdcb451323add..7675387690a44 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -22,6 +22,7 @@
 using namespace mlir;
 using namespace mlir::LLVM;
 using mlir::LLVM::detail::createIntrinsicCall;
+using mlir::LLVM::detail::createNvvmIntrinsicCall;
 
 static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType,
                                                   bool withPredicate) {

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ea4c32cf6a2ed..fb5319546adc3 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -35,6 +35,7 @@
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InlineAsm.h"
+#include "llvm/IR/IntrinsicsNVPTX.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
@@ -300,6 +301,29 @@ llvm::Value *mlir::LLVM::detail::createIntrinsicCall(
   return builder.CreateCall(fn, args);
 }
 
+/// Emits a call to the NVVM intrinsic `intrinsic` with `args` at the
+/// builder's insertion point.
+///
+/// The WMMA load/store intrinsics are overloaded on the type of their first
+/// (pointer) argument, so that type selects the declaration. The two WMMA MMA
+/// intrinsics listed below, and any non-overloaded intrinsic, are declared
+/// without overload types.
+llvm::Value *
+mlir::LLVM::detail::createNvvmIntrinsicCall(llvm::IRBuilderBase &builder,
+                                            llvm::Intrinsic::ID intrinsic,
+                                            ArrayRef<llvm::Value *> args) {
+  llvm::Module *module = builder.GetInsertBlock()->getModule();
+  bool isWMMAMma =
+      intrinsic == llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f16_f16 ||
+      intrinsic == llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f32_f32;
+
+  llvm::Function *fn;
+  if (llvm::Intrinsic::isOverloaded(intrinsic) && !isWMMAMma) {
+    // NVVM load and store intrinsic names are overloaded on the
+    // source/destination pointer type. The pointer is the first argument of
+    // the corresponding NVVM op.
+    assert(!args.empty() &&
+           "overloaded NVVM intrinsic expects a leading pointer argument");
+    fn = llvm::Intrinsic::getDeclaration(module, intrinsic,
+                                         {args[0]->getType()});
+  } else {
+    // An empty overload-type list is the same as the plain declaration.
+    fn = llvm::Intrinsic::getDeclaration(module, intrinsic);
+  }
+  return builder.CreateCall(fn, args);
+}
+
 /// Given a single MLIR operation, create the corresponding LLVM IR operation
 /// using the `builder`.
 LogicalResult

diff  --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 4879c51479e06..58eca3b875855 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -458,3 +458,116 @@ func @memcpy_incompatible_shape(%dst : memref<7xf32>, %src : memref<9xf32>) {
   // expected-error @+1 {{'gpu.memcpy' op arguments have incompatible shape}}
   gpu.memcpy %dst, %src  : memref<7xf32>, memref<9xf32>
 }
+
+// -----
+
+// A three-dimensional !gpu.mma_matrix result type is rejected; fragments must
+// be 2-D.
+func @mmamatrix_invalid_shape(){
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+    %i = constant 16 : index
+    // expected-error @+1 {{MMAMatrixType must have exactly two dimensions}}
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16x16xf16, "AOp">
+    return
+}
+
+// -----
+
+// Fragment roles other than AOp/BOp/COp/DOp (here "EOp") are rejected.
+func @mmamatrix_operand_type(){
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+    %i = constant 16 : index
+    // expected-error @+1 {{operand expected to be one of AOp, BOp, COp or DOp}}
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "EOp">
+    return
+}
+
+// -----
+
+// Only f16 and f32 elements are accepted in !gpu.mma_matrix; i32 is rejected.
+func @mmamatrix_invalid_element_type(){
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+    %i = constant 16 : index
+    // expected-error @+1 {{MMAMatrixType elements must be F16 or F32}}
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xi32, "AOp">
+    return
+}
+
+// -----
+
+// Loads require the source memref to use the identity layout; a transposed
+// (column-major) layout map is rejected.
+#layout_map_col_major = affine_map<(i, j) -> (j, i)>
+
+func @mmaLoadOp_identity_layout(){
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, #layout_map_col_major, 3>
+    %i = constant 16 : index
+    // expected-error @+1 {{expected identity layout map for source memref}}
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, #layout_map_col_major, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    return
+}
+
+// -----
+
+// Loads only accept memrefs in the generic, shared or global memory spaces;
+// memory space 5 is rejected.
+func @mmaLoadOp_invalid_mem_space(){
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 5>
+    %i = constant 16 : index
+    // expected-error @+1 {{source memorySpace kGenericMemorySpace, kSharedMemorySpace or kGlobalMemorySpace only allowed}}
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 5> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    return
+}
+
+// -----
+
+// Only A, B and C fragments can be loaded; a DOp result type is rejected.
+func @mmaLoadOp_operand_type(){
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+    %i = constant 16 : index
+    // expected-error @+1 {{only AOp, BOp and COp can be loaded}}
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "DOp">
+    return
+}
+
+// -----
+
+// Stores require the destination memref to use the identity layout; a
+// transposed (column-major) layout map is rejected.
+#layout_map_col_major = affine_map<(i, j) -> (j, i)>
+
+func @wmmaStoreOp_invalid_map(%arg0 : !gpu.mma_matrix<16x16xf16, "DOp">) -> () {
+    %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, #layout_map_col_major, 3>
+    %i = constant 16 : index
+    %j = constant 16 : index
+    // expected-error @+1 {{expected identity layout map for destination memref}}
+    gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16,#layout_map_col_major, 3>
+    return
+}
+
+// -----
+
+// Stores reject destination memrefs outside the generic, global and shared
+// memory spaces (here memory space 5).
+func @wmmaStoreOp_invalid_mem_space(%arg0 : !gpu.mma_matrix<16x16xf16, "DOp">) -> () {
+    %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 5>
+    %i = constant 16 : index
+    %j = constant 16 : index
+    // expected-error @+1 {{destination memorySpace of kGenericMemorySpace, kGlobalMemorySpace or kSharedMemorySpace only allowed}}
+    gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 5>
+    return
+}
+
+// -----
+
+// Only DOp fragments can be stored; storing an AOp fragment is rejected.
+func @wmmaStoreOp_invalid_store_operand(%arg0 : !gpu.mma_matrix<16x16xf16, "AOp">) -> () {
+    %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3>
+    %i = constant 16 : index
+    %j = constant 16 : index
+    // expected-error @+1 {{expected the operand matrix being stored to have 'DOp' operand type}}
+    gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "AOp">, memref<32x32xf16, 3>
+    return
+}
+
+// -----
+
+// subgroup_mma_compute checks fragment roles by position; passing B before A
+// is rejected.
+func @wmmaMmaOp_invalid_operand_order(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
+    // expected-error @+1 {{operands must be in the order AOp, BOp, COp}}
+    %D = gpu.subgroup_mma_compute %B, %A, %C : !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
+    return
+}
+
+// -----
+
+// A 16x32 A fragment cannot multiply a 16x16 B fragment: the inner dimensions
+// (32 vs. 16) differ.
+func @wmmaMmaOp_invalid_operand_shapes(%A : !gpu.mma_matrix<16x32xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
+    // expected-error @+1 {{operand shapes do not satisfy matmul constraints}}
+    %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x32xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
+    return
+}

diff  --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 7c1420115fc7b..a98fe1c496838 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -194,4 +194,15 @@ module attributes {gpu.container_module} {
     %1 = gpu.memcpy async [%0] %dst, %src : memref<3x7xf32>, memref<3x7xf32, 1>
     return
   }
+
+  // Round-trips a well-formed 16x16 f16 A-fragment load from a memref in
+  // memory space 3, including the leadDimension attribute.
+  func @mmamatrix_valid_element_type(){
+    // CHECK-LABEL: func @mmamatrix_valid_element_type
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+    // CHECK: %[[wg:.*]] = memref.alloca()
+    %i = constant 16 : index
+    // CHECK: %[[i:.*]] = constant 16 : index
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    // CHECK: gpu.subgroup_mma_load_matrix %[[wg]][%[[i]], %[[i]]] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    return
+  }
 }

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 172a79d38c4d5..4c62f38a4af2c 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -843,3 +843,162 @@ module {
     llvm.return
   }
 }
+
+// -----
+
+// Fragment loads only accept an i32 pointer in address space 0, 1 or 3; a
+// pointer in address space 5 is rejected.
+llvm.func @wmmaLoadOp_invalid_mem_space(%arg0: !llvm.ptr<i32, 5>, %arg1: i32) {
+  // expected-error@+1 {{'nvvm.wmma.m16n16k16.load.a.f16.row.stride' op expected operands to be a source pointer in memory space 0, 1, 3 followed by ldm of the source}}
+  %0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32, 5>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+
+  llvm.return
+}
+
+// -----
+
+// The leading-dimension (ldm) operand is required; a lone pointer operand is
+// rejected.
+llvm.func @wmmaLoadOp_invalid_missing_ldm(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
+  // expected-error@+1 {{'nvvm.wmma.m16n16k16.load.a.f16.row.stride' op expected operands to be a source pointer in memory space 0, 1, 3 followed by ldm of the source}}
+  %0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0: (!llvm.ptr<i32, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+
+  llvm.return
+}
+
+// -----
+
+// A-fragment loads must produce 8 <2 x f16> values; a 7-element struct is
+// rejected.
+llvm.func @wmmaLoadOp_invalid_AOp(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
+  // expected-error@+1 {{'nvvm.wmma.m16n16k16.load.a.f16.row.stride' op expected result type of loadAOp and loadBOp to be a struct of 8 <halfx2>s}}
+  %0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+
+  llvm.return
+}
+
+// -----
+
+// NOTE(review): this case is identical to the A-fragment result-type case just
+// before it (same op, operands and 7-element result). Consider removing it or
+// pointing it at a different mismatch.
+llvm.func @wmmaLoadOp_invalid_AOp(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
+  // expected-error@+1 {{'nvvm.wmma.m16n16k16.load.a.f16.row.stride' op expected result type of loadAOp and loadBOp to be a struct of 8 <halfx2>s}}
+  %0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+
+  llvm.return
+}
+
+// -----
+
+// B-fragment loads must produce 8 <2 x f16> values; a 7-element struct is
+// rejected.
+llvm.func @wmmaLoadOp_invalid_BOp(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
+  // expected-error@+1 {{'nvvm.wmma.m16n16k16.load.b.f16.row.stride' op expected result type of loadAOp and loadBOp to be a struct of 8 <halfx2>s}}
+  %0 = nvvm.wmma.m16n16k16.load.b.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+
+  llvm.return
+}
+
+// -----
+
+// C-fragment loads must produce 4 <2 x f16> or 8 f32 values; a 2-element
+// struct is rejected.
+llvm.func @wmmaLoadOp_invalid_COp(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
+  // expected-error@+1 {{'nvvm.wmma.m16n16k16.load.c.f16.row.stride' op expected result type of loadCOp to be a struct of 4 <halfx2>s or 8 f32s}}
+  %0 = nvvm.wmma.m16n16k16.load.c.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+
+  llvm.return
+}
+
+// -----
+
+// Fragment stores only accept an i32 pointer in address space 0, 1 or 3; a
+// pointer in address space 5 is rejected. The diagnostic text ("memoryspace")
+// mirrors the verifier's message verbatim.
+llvm.func @wmmaStoreOp_invalid_mem_space(%arg0: !llvm.ptr<i32, 5>, %arg1: vector<2 x f16>,
+                            %arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
+                            %arg4: vector<2 xf16>, %arg5: i32) {
+  // expected-error@+1 {{'nvvm.wmma.m16n16k16.store.d.f16.row.stride' op expected operands to be a source pointer in memoryspace 0, 1, 3 followed by ldm of the source}}
+  nvvm.wmma.m16n16k16.store.d.f16.row.stride %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : !llvm.ptr<i32, 5>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, i32
+  llvm.return
+}
+
+// -----
+
+// The trailing i32 ldm operand of a fragment store is required; omitting it
+// is rejected.
+llvm.func @wmmaStoreOp_invalid_missing_ldm(%arg0: !llvm.ptr<i32, 3>, %arg1: vector<2 x f16>,
+                            %arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
+                            %arg4: vector<2 xf16>, %arg5: i32) {
+  // expected-error@+1 {{'nvvm.wmma.m16n16k16.store.d.f16.row.stride' op expected operands to be a source pointer in memoryspace 0, 1, 3 followed by ldm of the source}}
+  nvvm.wmma.m16n16k16.store.d.f16.row.stride %arg0, %arg1, %arg2, %arg3, %arg4 : !llvm.ptr<i32, 3>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>
+  llvm.return
+}
+
+// -----
+
+// The f16 x f16 -> f16 MMA needs exactly 20 <2 x f16> operands; 19 are
+// rejected.
+llvm.func @gpu_wmma_mma_op_invalid_operands(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>,
+                        %arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
+                        %arg4: vector<2 x f16>, %arg5: vector<2 x f16>,
+                        %arg6: vector<2 x f16>, %arg7: vector<2 x f16>,
+                        %arg8: vector<2 x f16>, %arg9: vector<2 x f16>,
+                        %arg10: vector<2 x f16>, %arg11: vector<2 x f16>,
+                        %arg12: vector<2 x f16>, %arg13: vector<2 x f16>,
+                        %arg14: vector<2 x f16>, %arg15: vector<2 x f16>,
+                        %arg16: vector<2 x f16>, %arg17: vector<2 x f16>,
+                        %arg18: vector<2 x f16>) {
+  // expected-error@+1 {{'nvvm.wmma.m16n16k16.mma.row.row.f16.f16' op expected 20 <halfx2>s as operands}}
+  %0 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18 : vector<2 x f16> -> !llvm.struct<(vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)>
+  llvm.return
+}
+
+// -----
+
+// The f16 x f16 -> f16 MMA must return a struct of 4 <2 x f16>; a 3-element
+// struct is rejected.
+llvm.func @gpu_wmma_mma_op_results(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>,
+                        %arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
+                        %arg4: vector<2 x f16>, %arg5: vector<2 x f16>,
+                        %arg6: vector<2 x f16>, %arg7: vector<2 x f16>,
+                        %arg8: vector<2 x f16>, %arg9: vector<2 x f16>,
+                        %arg10: vector<2 x f16>, %arg11: vector<2 x f16>,
+                        %arg12: vector<2 x f16>, %arg13: vector<2 x f16>,
+                        %arg14: vector<2 x f16>, %arg15: vector<2 x f16>,
+                        %arg16: vector<2 x f16>, %arg17: vector<2 x f16>,
+                        %arg18: vector<2 x f16>, %arg19: vector<2 x f16>) {
+  // expected-error@+1 {{expected result type to be a struct of 4 <halfx2>s}}
+  %0 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19 : vector<2 x f16> -> !llvm.struct<(vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)>
+  llvm.return
+}
+
+// -----
+
+// The f32 MMA's first 16 operands (A and B) must all be <2 x f16>; an f32 in
+// the 16th slot is rejected.
+llvm.func @gpu_wmma_mma_op_invalid_ab_operands(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>,
+                        %arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
+                        %arg4: vector<2 x f16>, %arg5: vector<2 x f16>,
+                        %arg6: vector<2 x f16>, %arg7: vector<2 x f16>,
+                        %arg8: vector<2 x f16>, %arg9: vector<2 x f16>,
+                        %arg10: vector<2 x f16>, %arg11: vector<2 x f16>,
+                        %arg12: vector<2 x f16>, %arg13: vector<2 x f16>,
+                        %arg14: vector<2 x f16>, %arg15: f32,
+                        %arg16: f32, %arg17: f32, %arg18: f32, %arg19: f32,
+                        %arg20: f32, %arg21: f32, %arg22: f32, %arg23: f32) {
+  // expected-error@+1 {{'nvvm.wmma.m16n16k16.mma.row.row.f32.f32' op expected 16 <halfx2>s for `a` and `b` operand}}
+  %0 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+  llvm.return
+}
+
+// -----
+
+// The f32 MMA's 8 C operands must all be f32; a <2 x f16> in the last slot is
+// rejected.
+llvm.func @gpu_wmma_mma_op_invalid_c_operand(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>,
+                        %arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
+                        %arg4: vector<2 x f16>, %arg5: vector<2 x f16>,
+                        %arg6: vector<2 x f16>, %arg7: vector<2 x f16>,
+                        %arg8: vector<2 x f16>, %arg9: vector<2 x f16>,
+                        %arg10: vector<2 x f16>, %arg11: vector<2 x f16>,
+                        %arg12: vector<2 x f16>, %arg13: vector<2 x f16>,
+                        %arg14: vector<2 x f16>, %arg15: vector<2xf16>,
+                        %arg16: f32, %arg17: f32, %arg18: f32, %arg19: f32,
+                        %arg20: f32, %arg21: f32, %arg22: f32, %arg23: vector<2xf16>) {
+  // expected-error@+1 {{'nvvm.wmma.m16n16k16.mma.row.row.f32.f32' op expected 8 f32s for `c` operand}}
+  %0 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+  llvm.return
+}
+
+// -----
+
+// The f32 MMA must return a struct of 8 f32s; a struct whose last field is
+// <2 x f16> is rejected.
+llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>,
+                        %arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
+                        %arg4: vector<2 x f16>, %arg5: vector<2 x f16>,
+                        %arg6: vector<2 x f16>, %arg7: vector<2 x f16>,
+                        %arg8: vector<2 x f16>, %arg9: vector<2 x f16>,
+                        %arg10: vector<2 x f16>, %arg11: vector<2 x f16>,
+                        %arg12: vector<2 x f16>, %arg13: vector<2 x f16>,
+                        %arg14: vector<2 x f16>, %arg15: vector<2xf16>,
+                        %arg16: f32, %arg17: f32, %arg18: f32, %arg19: f32,
+                        %arg20: f32, %arg21: f32, %arg22: f32, %arg23: f32) {
+  // expected-error@+1 {{'nvvm.wmma.m16n16k16.mma.row.row.f32.f32' op expected result type to be a struct of 8 f32s}}
+  %0 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, vector<2xf16>)>
+  llvm.return
+}

diff  --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index c1b29fe515ee0..8fefd7866fd5d 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -73,6 +73,43 @@ llvm.func @nvvm_mma(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
 }
 
+// Checks that nvvm.wmma.m16n16k16.load.a.f16.row.stride (addrspace(3) pointer + i32 stride) is
+// translated into a call to llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p3i32 returning 8 x <2 x half>.
+llvm.func @gpu_wmma_load_op(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
+  // CHECK: call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p3i32(i32 addrspace(3)* %{{.*}}, i32 %{{.*}})
+  %0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0, %arg1 : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+
+  llvm.return
+}
+
+// Checks that nvvm.wmma.m16n16k16.store.d.f16.row.stride (pointer, 4 x vector<2xf16>, i32 stride) is
+// translated into a call to llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p3i32.
+llvm.func @gpu_wmma_store_op(%arg0: !llvm.ptr<i32, 3>, %arg1: vector<2 x f16>,
+                            %arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
+                            %arg4: vector<2 x f16>, %arg5: i32) {
+  // CHECK: call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p3i32(i32 addrspace(3)* %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, i32 %{{.*}})
+  nvvm.wmma.m16n16k16.store.d.f16.row.stride %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : !llvm.ptr<i32, 3>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, i32
+  llvm.return
+}
+
+// The test below checks the correct mapping of the nvvm.wmma.*.mma.* op to the correct intrinsic
+// in the LLVM NVPTX backend.
+llvm.func @gpu_wmma_mma_op(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>,
+                        %arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
+                        %arg4: vector<2 x f16>, %arg5: vector<2 x f16>,
+                        %arg6: vector<2 x f16>, %arg7: vector<2 x f16>,
+                        %arg8: vector<2 x f16>, %arg9: vector<2 x f16>,
+                        %arg10: vector<2 x f16>, %arg11: vector<2 x f16>,
+                        %arg12: vector<2 x f16>, %arg13: vector<2 x f16>,
+                        %arg14: vector<2 x f16>, %arg15: vector<2 x f16>,
+                        %arg16: vector<2 x f16>, %arg17: vector<2 x f16>,
+                        %arg18: vector<2 x f16>, %arg19: vector<2 x f16>) {
+  // CHECK: call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}})
+  %0 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19 : vector<2 x f16> -> !llvm.struct<(vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)>
+
+  llvm.return
+}
+
 // This function has the "kernel" attribute attached and should appear in the
 // NVVM annotations after conversion.
 llvm.func @kernel_func() attributes {nvvm.kernel} {


        


More information about the Mlir-commits mailing list