[Mlir-commits] [mlir] 3be7c28 - [mlir][NVVM] Add support for nvvm mma.sync ops

Thomas Raoux llvmlistbot at llvm.org
Fri Mar 25 10:32:07 PDT 2022


Author: Christopher Bate
Date: 2022-03-25T17:28:05Z
New Revision: 3be7c2891798025e524fce1b7cdaa27aee6c9816

URL: https://github.com/llvm/llvm-project/commit/3be7c2891798025e524fce1b7cdaa27aee6c9816
DIFF: https://github.com/llvm/llvm-project/commit/3be7c2891798025e524fce1b7cdaa27aee6c9816.diff

LOG: [mlir][NVVM] Add support for nvvm mma.sync ops

This patch adds MLIR NVVM support for the various NVPTX `mma.sync`
operations. There are a number of possible data type, shape,
and other attribute combinations supported by the operation, so a
custom assembly format is added and attributes are inferred where
possible.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D122410

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
    mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/test/Dialect/LLVMIR/nvvm.mlir
    mlir/test/Target/LLVMIR/nvvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index f6513a242d3fc..c44d792abf01d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -35,6 +35,8 @@ set(LLVM_TARGET_DEFINITIONS NVVMOps.td)
 mlir_tablegen(NVVMConversions.inc -gen-llvmir-conversions)
 mlir_tablegen(NVVMOpsEnums.h.inc -gen-enum-decls)
 mlir_tablegen(NVVMOpsEnums.cpp.inc -gen-enum-defs)
+mlir_tablegen(NVVMOpsStructs.h.inc -gen-struct-attr-decls)
+mlir_tablegen(NVVMOpsStructs.cpp.inc -gen-struct-attr-defs)
 mlir_tablegen(NVVMOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=nvvm)
 mlir_tablegen(NVVMOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=nvvm)
 add_public_tablegen_target(MLIRNVVMConversionsIncGen)

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index de942f6fb4d31..7be5f49c0326a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -21,6 +21,7 @@
 #include "llvm/IR/IntrinsicsNVPTX.h"
 
 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.h.inc"
+#include "mlir/Dialect/LLVMIR/NVVMOpsStructs.h.inc"
 
 namespace mlir {
 namespace NVVM {

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index d65b525eacf6f..f9d32f480888a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -195,18 +195,6 @@ def NVVM_CpAsyncWaitGroupOp : NVVM_Op<"cp.async.wait.group">,
   let assemblyFormat = "$n attr-dict";
 }
 
-def NVVM_MmaOp :
-  NVVM_Op<"mma.sync">,
-  Results<(outs LLVM_Type:$res)>,
-  Arguments<(ins Variadic<LLVM_Type>:$args)> {
-  string llvmBuilder = [{
-    $res = createIntrinsicCall(
-        builder, llvm::Intrinsic::nvvm_mma_m8n8k4_row_col_f32_f32, $args);
-  }];
-  let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
-  let hasVerifier = 1;
-}
-
 /// Helpers to instantiate different version of wmma intrinsics.
 /// This matches the hierarchy used in IntrinsicsNVVM.td to define all the
 /// combinations of the intrinsics.
@@ -296,6 +284,7 @@ class MMA_LDST_OPS<list<GEOM> Geom, list<string> Frags, list<string> Types> {
 // Creates list of valid combinations of fragments. This is a subset of what
 // llvm supports and can be extended as needed.
 class NVVM_MMA_OPS {
+  // "wmma" operations
   list<list<WMMA_REGS>> tf32_wmma_ops = MMA_OPS<
             [GEOM<16, 16, 8>],
             ["tf32"], [], ["f32"], []>.ret;
@@ -324,6 +313,32 @@ class NVVM_MMA_OPS {
   // Separate A/B/C fragments (loads) from D (stores).
   list<WMMA_REGS> all_ld_ops = !filter(op, all_ldst_ops, !ne(op.frag, "d"));
   list<WMMA_REGS> all_st_ops = !filter(op, all_ldst_ops, !eq(op.frag, "d"));
+
+  // "mma.sync" operations, grouped by multiplicand PTX type.
+  list<list<WMMA_REGS>> tf32_mma_ops = MMA_OPS<
+            [GEOM<16,8,4>, GEOM<16,8,8>],
+            ["tf32"], [], ["f32"], []>.ret;
+  list<list<WMMA_REGS>> bf16_mma_ops = MMA_OPS<
+            [GEOM<16,8,16>, GEOM<16,8,8>],
+            ["bf16"], [], ["f32"], []>.ret;
+  list<list<WMMA_REGS>> f64_mma_ops = MMA_OPS<
+            [GEOM<8,8,4>],
+            ["f64"], [], ["f64"], []>.ret;
+  list<list<WMMA_REGS>> fp_mma_ops = MMA_OPS<
+            [GEOM<8,8,4>, GEOM<16,8,8>, GEOM<16,8,16>],
+            ["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
+  list<list<WMMA_REGS>> int_mma_ops = MMA_OPS<
+            [GEOM<8,8,16>, GEOM<16,8,16>, GEOM<16,8,32>],
+            ["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret;
+  list<list<WMMA_REGS>> subint_mma_ops = MMA_OPS<
+            [GEOM<8,8,32>, GEOM<16,8,32>, GEOM<16,8,64>],
+            ["s4", "u4"], ["s4", "u4"], ["s32"], []>.ret;
+  list<list<WMMA_REGS>> bit_mma_ops = MMA_OPS<
+            [GEOM<8,8,128>, GEOM<16,8,128>, GEOM<16,8,256>],
+            ["b1"], [], ["s32"], []>.ret;
+  list<list<WMMA_REGS>> all_mma_sync_ops = !listconcat( // iterated by MMA_SYNC_INTR
+            tf32_mma_ops, bf16_mma_ops, f64_mma_ops,
+            fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops);
 }
 
 def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -405,6 +420,150 @@ class MMA_MMA_INTR<string opName> {
   string id = !foldl("", f, acc, el, acc # "\n" # el);
 }
 
+/// Enum attribute for the binary (b1) MMA multiply-and-accumulate variant.
+def MMAB1OpNone : I32EnumAttrCase<"none", 0>; // NOTE: no mma.sync intrinsic maps to this
+def MMAB1OpXorPopc : I32EnumAttrCase<"xor_popc", 1>;
+def MMAB1OpAndPopc : I32EnumAttrCase<"and_popc", 2>;
+def MMAB1Op : I32EnumAttr<"MMAB1Op", "MMA binary operations",
+  [MMAB1OpNone, MMAB1OpXorPopc, MMAB1OpAndPopc]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def MMAB1OpAttr : EnumAttr<NVVM_Dialect, MMAB1Op, "mma_b1op"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+/// Overflow behavior of MMA integer ops (`satfinite` => "_satfinite" intrinsic).
+def MMAIntOverflowWrap : I32EnumAttrCase<"wrapped", 0>;
+def MMAIntOverflowSat : I32EnumAttrCase<"satfinite", 1>;
+def MMAIntOverflow : I32EnumAttr<"MMAIntOverflow", "MMA overflow options",
+  [MMAIntOverflowSat, MMAIntOverflowWrap]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def MMAIntOverflowAttr : EnumAttr<NVVM_Dialect, MMAIntOverflow, "mma_int_overflow"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+/// Struct attribute holding the MMA shape as i32 fields `m`, `n` and `k`.
+def NVVM_MMAShapeAttr : StructAttr<"MMAShapeAttr", NVVM_Dialect, [
+    StructFieldAttr<"m", I32Attr>,
+    StructFieldAttr<"n", I32Attr>,
+    StructFieldAttr<"k", I32Attr>
+  ]> {
+  let summary = "Attribute for MMA operation shape.";
+}
+
+// Returns true if this combination of fragments ([A, B, C, D]), layouts and
+// satf for mma.sync ops is supported; false otherwise. The !cond clauses are
+// checked in order and the first one that holds determines the result. E.g.
+// if NVVM_MMA_SUPPORTED<...>.ret then
+//   def : FOO<>; // The record will only be defined for supported ops.
+//
+class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b, int satf> {
+  // MMA ops check both layouts.
+  string layout = layout_a # ":" # layout_b;
+  string a_type = frags[0].ptx_elt_type;
+  string b_type = frags[1].ptx_elt_type;
+  string c_type = frags[2].ptx_elt_type;
+  string d_type = frags[3].ptx_elt_type;
+  string geom = frags[0].geom;
+
+  // gcd is a shortcut used to identify instructions that depend on
+  // geom+frag_c+frag_d.
+  string gcd = geom # ":" # c_type # d_type;
+  bit ret = !cond(
+
+    // satf is only valid for the s8/u8/s4/u4 multiplicand types.
+    !and(!eq(satf, 1),
+         !ne(a_type, "s8"),
+         !ne(a_type, "u8"),
+         !ne(a_type, "s4"),
+         !ne(a_type, "u4")): false,
+
+    // m8n8k4 has no C=f32 D=f16 variant.
+    !eq(gcd, "m8n8k4:f32f16"): false,
+
+    // Only m8n8k4 with f16 multiplicands may use a layout other than row:col.
+    !and(!ne(layout, "row:col"),
+         !or(!ne(geom, "m8n8k4"),
+             !ne(a_type, "f16"))) : false,
+
+    // m16n8k8 requires A and B to be the same type and C and D to be the same
+    // type.
+    !and(!eq(geom, "m16n8k8"),
+         !or(!ne(a_type, b_type),
+             !ne(c_type, d_type))): false,
+
+    // m16n8k8 requires C and D to be the same type (already covered above).
+    !and(!eq(geom, "m16n8k8"),
+         !ne(c_type, d_type)): false,
+
+    // All other combinations are OK.
+    true: true
+  );
+}
+
+// Returns the operation suffixes of the possible b1 multiply-and-accumulate
+// variants when the A fragment has the b1 type ("xor_popc" first, so it is
+// the first match when no b1Op is requested). For all other fragments,
+// returns a list containing only the empty string.
+class NVVM_MMA_B1OPS<list<WMMA_REGS> frags> {
+  list<string> ret = !cond(
+    !eq(frags[0].ptx_elt_type, "b1") : ["xor_popc", "and_popc"],
+    true: [""]
+  );
+}
+
+/// Generates the C++ enum name (llvm::Intrinsic::nvvm_mma_*) of an intrinsic.
+class MMA_SYNC_NAME<string ALayout, string BLayout, string b1op, int Satfinite,
+               WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+  string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+  string id = "llvm::Intrinsic::nvvm_mma"
+                # !if(!ne(b1op, ""), "_" # b1op, "")
+                # "_" # A.geom
+                # "_" # ALayout
+                # "_" # BLayout
+                # !if(Satfinite, "_satfinite", "")
+                # signature;  
+}
+
+/// Generates the body of MmaOp::getIntrinsicID: one `if (...) return <id>;`
+/// per supported configuration, emitted in order, so the first match wins.
+class MMA_SYNC_INTR {
+  list<list<list<list<list<string>>>>> cond0 = // [op][layoutA][layoutB][sat][b1op]
+    !foreach(op, NVVM_MMA_OPS.all_mma_sync_ops,
+      !foreach(layoutA, ["row", "col"],
+        !foreach(layoutB, ["row", "col"],
+          !foreach (sat, [0, 1],
+            !foreach (b1op, NVVM_MMA_B1OPS<op>.ret,
+              !if(NVVM_MMA_SUPPORTED<[op[0], op[1], op[2], op[3]],
+                                     layoutA, layoutB, sat>.ret,
+      "if (layoutA == \"" # layoutA #  "\" && layoutB == \"" # layoutB #  "\" && "
+      "    m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k #
+      "    && \"" # op[0].ptx_elt_type # "\" == eltypeA && \""
+       # op[1].ptx_elt_type # "\" == eltypeB && " 
+       # " \"" # op[2].ptx_elt_type # "\" == eltypeC && "
+       # " \"" # op[3].ptx_elt_type # "\" == eltypeD "
+       # " && (sat.hasValue()  ? " # sat # " == static_cast<int>(*sat) : true)"
+       # !if(!ne(b1op, ""), " && (b1Op.hasValue() ? MMAB1Op::" # b1op # " == b1Op.getValue() : true)", "") # ")\n"
+       # "  return " #
+       MMA_SYNC_NAME<layoutA, layoutB, b1op, sat, op[0], op[1], op[2], op[3]>.id # ";", 
+          "") // if supported
+          ) // b1op
+        ) // sat
+      ) // layoutB
+    ) // layoutA
+  ); // all_mma_sync_ops
+  list<list<list<string>>> f1 = !foldl([[[""]]], // flatten op, layoutA, layoutB
+                                  !foldl([[[[""]]]], cond0, acc, el, 
+                                      !listconcat(acc, el)),
+                                    acc1, el1, !listconcat(acc1, el1));
+  list<list<string>> f2 = !foldl([[""]], f1, acc1, el1, !listconcat(acc1, el1)); // flatten sat
+  list<string> f3 = !foldl([""], f2, acc, el, !listconcat(acc, el)); // flatten b1op
+  string id = !foldl("", f3, acc, el, acc # "\n" # el); // unsupported combos give blank lines
+}
+
 def MMALayoutRow : I32EnumAttrCase<"row", 0>;
 def MMALayoutCol : I32EnumAttrCase<"col", 1>;
 
@@ -418,13 +577,24 @@ def MMALayoutAttr : EnumAttr<NVVM_Dialect, MMALayout, "mma_layout"> {
   let assemblyFormat = "`<` $value `>`";
 }
 
+/// Enum attribute of the different PTX element types used for MMA operands.
 def MMATypeF16  : I32EnumAttrCase<"f16", 0>;
 def MMATypeF32  : I32EnumAttrCase<"f32", 1>;
 def MMATypeTF32 : I32EnumAttrCase<"tf32", 2>;
+def MMATypeU8 : I32EnumAttrCase<"u8", 3>;
+def MMATypeS8 : I32EnumAttrCase<"s8", 4>;
+def MMATypeS32 : I32EnumAttrCase<"s32", 5>;
+def MMATypeB1 : I32EnumAttrCase<"b1", 6>;
+def MMATypeU4 : I32EnumAttrCase<"u4", 7>;
+def MMATypeS4 : I32EnumAttrCase<"s4", 8>;
+def MMATypeBF16 : I32EnumAttrCase<"bf16", 9>;
+def MMATypeF64 : I32EnumAttrCase<"f64", 10>;
 
-/// Enum attribute of the different matrix types.
 def MMATypes : I32EnumAttr<"MMATypes", "NVVM MMA types",
-  [MMATypeF16, MMATypeF32, MMATypeTF32]> {
+  [MMATypeF16, MMATypeF32, MMATypeTF32,
+  MMATypeBF16, MMATypeS8, MMATypeU8,
+  MMATypeS32, MMATypeS4, MMATypeU4,
+  MMATypeB1, MMATypeF64]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::NVVM";
 }
@@ -678,4 +848,141 @@ def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
   let hasVerifier = 1;
 }
 
+def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
+
+  let summary = "cooperative matrix-multiply and accumulate";
+
+  let description = [{
+    The `nvvm.mma.sync` operation collectively performs the operation
+    `D = matmul(A, B) + C` using all threads in a warp.
+
+    All the threads in the warp must execute the same `mma.sync` operation.
+
+    For each possible multiplicand PTX data type, there are one or more possible
+    instruction shapes given as "mMnNkK". The table below describes the
+    possibilities as well as the types required for the operands. Note that
+    the data type for C (the accumulator) and D (the result) can vary
+    independently when there are multiple possibilities in the "C/D Type" column.
+
+    When an optional attribute cannot be immediately inferred from the types of
+    the operands and the result during parsing or validation, an error will be
+    raised.
+
+    `b1Op` is only relevant when `multiplicandAPtxType` is the binary (b1)
+    type. It specifies how the multiply-and-accumulate is performed and is
+    either `xor_popc` or `and_popc`. The default is `xor_popc`.
+
+    `intOverflowBehavior` is only relevant when the `multiplicandAPtxType`
+    attribute is one of `u8, s8, u4, s4`; it describes how overflow is handled
+    in the accumulator. When the attribute is `satfinite`, the accumulator
+    values are clamped in the int32 range on overflow. When it is `wrapped`
+    or absent, the non-saturating form is used and overflow wraps from one
+    end of the range to the other.
+
+    `layoutA` and `layoutB` are required and should generally be set to
+    `#nvvm.mma_layout<row>` and `#nvvm.mma_layout<col>` respectively, but other
+    combinations are possible for certain shapes according to the table below.
+
+    ```
+    | A/B Type | Shape      | ALayout | BLayout | A Type   | B Type   | C/D Type           |
+    |----------|------------|---------|---------|----------|----------|--------------------|
+    | f64      | .m8n8k4    | row     | col     | 1x f64   | 1x f64   | 2x f64             |
+    | f16      | .m8n8k4    | row/col | row/col | 2x f16x2 | 2x f16x2 | 4x f16x2 or 8x f32 |
+    |          | .m16n8k8   | row     | col     | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4x f32 |
+    |          | .m16n8k16  | row     | col     | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4x f32 |
+    | bf16     | .m16n8k8   | row     | col     | 2x f16x2 | 1x f16x2 | 4x f32             |
+    |          | .m16n8k16  | row     | col     | 4x f16x2 | 2x f16x2 | 4x f32             |
+    | tf32     | .m16n8k4   | row     | col     | 2x i32   | 1x i32   | 4x f32             |
+    |          | .m16n8k8   | row     | col     | 4x i32   | 2x i32   | 4x f32             |
+    | u8/s8    | .m8n8k16   | row     | col     | 1x i32   | 1x i32   | 2x i32             |
+    |          | .m16n8k16  | row     | col     | 2x i32   | 1x i32   | 4x i32             |
+    |          | .m16n8k32  | row     | col     | 4x i32   | 2x i32   | 4x i32             |
+    | u4/s4    | .m8n8k32   | row     | col     | 1x i32   | 1x i32   | 2x i32             |
+    |          | .m16n8k32  | row     | col     | 2x i32   | 1x i32   | 4x i32             |
+    |          | .m16n8k64  | row     | col     | 4x i32   | 2x i32   | 4x i32             |
+    | b1       | .m8n8k128  | row     | col     | 1x i32   | 1x i32   | 2x i32             |
+    |          | .m16n8k128 | row     | col     | 2x i32   | 1x i32   | 4x i32             |
+    |          | .m16n8k256 | row     | col     | 4x i32   | 2x i32   | 4x i32             |
+    ```
+
+    Example:
+    ```mlir
+
+    %128 = nvvm.mma.sync A[%120, %121, %122, %123]
+                         B[%124, %125]
+                         C[%126, %127]
+                         {layoutA = #nvvm.mma_layout<row>,
+                          layoutB = #nvvm.mma_layout<col>,
+                          shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}}
+        : (vector<2xf16>, vector<2xf16>, vector<2xf16>)
+           -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+    ```
+  }];
+
+  let results = (outs LLVM_AnyStruct:$res);
+  let arguments = (ins NVVM_MMAShapeAttr:$shape,
+             OptionalAttr<MMAB1OpAttr>:$b1Op,
+             OptionalAttr<MMAIntOverflowAttr>:$intOverflowBehavior,
+             MMALayoutAttr:$layoutA,
+             MMALayoutAttr:$layoutB,
+             OptionalAttr<MMATypesAttr>:$multiplicandAPtxType,
+             OptionalAttr<MMATypesAttr>:$multiplicandBPtxType,
+             Variadic<LLVM_Type>:$operandA,
+             Variadic<LLVM_Type>:$operandB,
+             Variadic<LLVM_Type>:$operandC);
+
+  let extraClassDeclaration = !strconcat([{
+      static llvm::Intrinsic::ID getIntrinsicID(
+            int64_t m, int64_t n, int64_t k,
+            llvm::Optional<MMAB1Op> b1Op,
+            llvm::Optional<MMAIntOverflow> sat,
+            mlir::NVVM::MMALayout layoutAEnum, mlir::NVVM::MMALayout layoutBEnum,
+            mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum,
+            mlir::NVVM::MMATypes eltypeCEnum, mlir::NVVM::MMATypes eltypeDEnum) {
+        llvm::StringRef layoutA = stringifyEnum(layoutAEnum);
+        llvm::StringRef layoutB = stringifyEnum(layoutBEnum);
+        llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum);
+        llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum);
+        llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum);
+        llvm::StringRef eltypeD = stringifyEnum(eltypeDEnum);
+        }],
+        MMA_SYNC_INTR<>.id, [{
+        return 0;
+      }
+
+      static Optional<mlir::NVVM::MMATypes> inferOperandMMAType(Type operandElType,
+        bool isAccumulator);
+
+      MMATypes accumPtxType();
+      MMATypes resultPtxType();
+    }]);
+
+  let builders = [
+      OpBuilder<(ins  "Type":$resultType, "ValueRange":$operandA,
+        "ValueRange":$operandB, "ValueRange":$operandC,
+        "ArrayRef<int64_t>":$shape, "Optional<MMAB1Op>":$b1Op,
+        "Optional<MMAIntOverflow>":$intOverflow,
+        "Optional<std::array<MMATypes, 2>>":$multiplicandPtxTypes,
+        "Optional<std::array<MMALayout, 2>>":$multiplicandLayouts)>
+    ];
+
+  string llvmBuilder = [{
+    auto operands = moduleTranslation.lookupValues(opInst.getOperands());
+    auto intId = mlir::NVVM::MmaOp::getIntrinsicID(
+        $shape.m().getInt(), $shape.n().getInt(), $shape.k().getInt(),
+        $b1Op, $intOverflowBehavior,
+        $layoutA, $layoutB,
+        $multiplicandAPtxType.getValue(),
+        $multiplicandBPtxType.getValue(),
+        op.accumPtxType(),
+        op.resultPtxType());
+
+    $res = createIntrinsicCall(
+      builder, intId, operands);
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+}
+
 #endif // NVVMIR_OPS

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 9b909a2b1a5bb..c2c1eb49e1726 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -34,6 +34,7 @@ using namespace NVVM;
 
 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
+#include "mlir/Dialect/LLVMIR/NVVMOpsStructs.cpp.inc"
 
 //===----------------------------------------------------------------------===//
 // Printing/parsing for NVVM ops
@@ -69,47 +70,455 @@ LogicalResult CpAsyncOp::verify() {
   return success();
 }
 
+// Returns the PTX type (`NVVM::MMATypes`) matching the element type of an MMA
+// operand or result, or `llvm::None` when no PTX type can be inferred.
+// `isAccumulator` only matters for integer types: integer registers are
+// unambiguously s32 in accumulator/result position, but multiplicand integer
+// registers may hold several PTX types (e.g. s8, u4, b1 or tf32). LLVM struct
+// types are classified by their first member.
+Optional<mlir::NVVM::MMATypes> MmaOp::inferOperandMMAType(Type operandElType,
+                                                          bool isAccumulator) {
+  // Aggregates defer to their leading member, if they have one.
+  if (auto structTy = operandElType.dyn_cast<LLVM::LLVMStructType>()) {
+    ArrayRef<Type> members = structTy.getBody();
+    if (members.empty())
+      return llvm::None;
+    return inferOperandMMAType(members.front(), isAccumulator);
+  }
+
+  // Integer registers only identify a PTX type in accumulator position.
+  if (operandElType.isa<IntegerType>()) {
+    if (!isAccumulator)
+      return llvm::None;
+    return NVVM::MMATypes::s32;
+  }
+
+  MLIRContext *ctx = operandElType.getContext();
+  Type f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
+  if (operandElType.isF16() || operandElType == f16x2Ty)
+    return NVVM::MMATypes::f16;
+  if (operandElType.isF32())
+    return NVVM::MMATypes::f32;
+  if (operandElType.isF64())
+    return NVVM::MMATypes::f64;
+  return llvm::None;
+}
+
+/// Returns true for the 4-bit integer PTX types (u4 and s4).
+static bool isInt4PtxType(MMATypes type) {
+  switch (type) {
+  case MMATypes::u4:
+  case MMATypes::s4:
+    return true;
+  default:
+    return false;
+  }
+}
+
+/// Returns true for the 8-bit integer PTX types (u8 and s8).
+static bool isInt8PtxType(MMATypes type) {
+  switch (type) {
+  case MMATypes::u8:
+  case MMATypes::s8:
+    return true;
+  default:
+    return false;
+  }
+}
+
+/// Returns true for the integer-valued PTX types: the 4- and 8-bit integers,
+/// the single-bit b1 type and the s32 accumulator type.
+static bool isIntegerPtxType(MMATypes type) {
+  if (type == MMATypes::b1 || type == MMATypes::s32)
+    return true;
+  return isInt8PtxType(type) || isInt4PtxType(type);
+}
+
+/// Returns the PTX type of the accumulator (C) fragment, inferred from the
+/// type of its first register. Asserts that the type is inferrable.
+MMATypes MmaOp::accumPtxType() {
+  Type firstAccumRegTy = getODSOperands(2).getTypes().front();
+  Optional<mlir::NVVM::MMATypes> ptxTy =
+      inferOperandMMAType(firstAccumRegTy, /*isAccum=*/true);
+  assert(ptxTy.hasValue() && "accumulator PTX type should always be inferrable");
+  return *ptxTy;
+}
+
+/// Returns the PTX type of the result (D), inferred from the op's result type
+/// with accumulator semantics. Asserts that the type is inferrable.
+MMATypes MmaOp::resultPtxType() {
+  Type resultTy = getResult().getType();
+  Optional<mlir::NVVM::MMATypes> ptxTy =
+      inferOperandMMAType(resultTy, /*isAccum=*/true);
+  assert(ptxTy.hasValue() && "result PTX type should always be inferrable");
+  return *ptxTy;
+}
+
+// Prints the op in the custom form
+//   A[<regs>] B[<regs>] C[<regs>] attr-dict : (typeA, typeB, typeC) -> type
+// where each operand type is that of the first register of the fragment. The
+// operand segment sizes are always elided. A multiplicand PTX type attribute
+// is elided only when the parser would re-infer exactly the same value from
+// the operand type; otherwise the printed form would not round-trip.
+void MmaOp::print(OpAsmPrinter &p) {
+  SmallVector<Type, 4> regTypes;
+  struct OperandFragment {
+    StringRef operandName;
+    StringRef ptxTypeAttr;
+    SmallVector<Value, 4> regs;
+    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
+        : operandName(name), ptxTypeAttr(ptxTypeName) {}
+  };
+
+  std::array<OperandFragment, 3> frags{
+      OperandFragment("A", multiplicandAPtxTypeAttrName()),
+      OperandFragment("B", multiplicandBPtxTypeAttrName()),
+      OperandFragment("C", "")};
+  // Explicit PTX types of the A and B multiplicands, indexed like `frags`.
+  std::array<Optional<MMATypes>, 2> explicitPtxTypes{multiplicandAPtxType(),
+                                                     multiplicandBPtxType()};
+  SmallVector<StringRef, 4> ignoreAttrNames{
+      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
+
+  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
+    auto &frag = frags[fragIdx];
+    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
+    for (auto operandIdx = varOperandSpec.first;
+         operandIdx < varOperandSpec.first + varOperandSpec.second;
+         operandIdx++) {
+      frag.regs.push_back(this->getOperand(operandIdx));
+      // Record the type of the first register of *this* fragment. (Comparing
+      // against 0 would only match the first A register, so B and C would be
+      // classified by A's type.)
+      if (operandIdx == varOperandSpec.first)
+        regTypes.push_back(this->getOperand(operandIdx).getType());
+    }
+    // Only the A/B fragments carry a PTX type attribute; skip C and any empty
+    // fragment (no type was recorded for it above).
+    if (fragIdx >= explicitPtxTypes.size() || frag.regs.empty())
+      continue;
+    Optional<MMATypes> inferredType =
+        inferOperandMMAType(regTypes.back(), /*isAccum=*/false);
+    if (inferredType && inferredType == explicitPtxTypes[fragIdx])
+      ignoreAttrNames.push_back(frag.ptxTypeAttr);
+  }
+
+  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
+    p << " " << frag.operandName;
+    p << "[";
+    p.printOperands(frag.regs);
+    p << "] ";
+  };
+
+  for (const auto &frag : frags) {
+    printMmaOperand(frag);
+  }
+
+  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
+
+  // Print the types of the operands and result.
+  p << " : "
+    << "(";
+  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
+                                             frags[1].regs[0].getType(),
+                                             frags[2].regs[0].getType()},
+                        p);
+  p << ")";
+  p.printArrowTypeList(TypeRange{this->res().getType()});
+}
+
+/// Builds an mma.sync op from the A, B and C register lists.
+///  - `shape` must hold exactly (m, n, k); each becomes an i32 attribute.
+///  - Without `multiplicandPtxTypes`, the A/B PTX types are inferred from the
+///    first register of each fragment and attached only when inferrable.
+///  - Without `multiplicandLayouts`, A is row-major and B column-major.
+///  - `intOverflow` and `b1Op` are attached only when provided.
+void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
+                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
+                  ArrayRef<int64_t> shape, Optional<MMAB1Op> b1Op,
+                  Optional<MMAIntOverflow> intOverflow,
+                  Optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
+                  Optional<std::array<MMALayout, 2>> multiplicandLayouts) {
+  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
+  MLIRContext *ctx = builder.getContext();
+
+  // Operands: three variadic segments (A, B, C) plus their sizes.
+  result.addOperands(operandA);
+  result.addOperands(operandB);
+  result.addOperands(operandC);
+  result.addAttribute(
+      MmaOp::getOperandSegmentSizeAttr(),
+      builder.getI32VectorAttr({static_cast<int32_t>(operandA.size()),
+                                static_cast<int32_t>(operandB.size()),
+                                static_cast<int32_t>(operandC.size())}));
+  result.addTypes(resultType);
+
+  // Shape (m, n, k) as i32 integer attributes.
+  Type i32 = builder.getIntegerType(32);
+  auto shapeDim = [&](size_t dim) {
+    return builder.getIntegerAttr(i32, shape[dim]);
+  };
+  result.addAttribute("shape", MMAShapeAttr::get(shapeDim(0), shapeDim(1),
+                                                 shapeDim(2), ctx));
+
+  // Layouts default to row-major A and column-major B.
+  MMALayout layoutA = MMALayout::row;
+  MMALayout layoutB = MMALayout::col;
+  if (multiplicandLayouts) {
+    layoutA = (*multiplicandLayouts)[0];
+    layoutB = (*multiplicandLayouts)[1];
+  }
+  result.addAttribute("layoutA", MMALayoutAttr::get(ctx, layoutA));
+  result.addAttribute("layoutB", MMALayoutAttr::get(ctx, layoutB));
+
+  // Multiplicand PTX types: explicit when given, otherwise best-effort
+  // inference from the first register of each fragment.
+  Optional<MMATypes> ptxTypeA;
+  Optional<MMATypes> ptxTypeB;
+  if (multiplicandPtxTypes) {
+    ptxTypeA = (*multiplicandPtxTypes)[0];
+    ptxTypeB = (*multiplicandPtxTypes)[1];
+  } else {
+    ptxTypeA = inferOperandMMAType(operandA[0].getType(), false);
+    ptxTypeB = inferOperandMMAType(operandB[0].getType(), false);
+  }
+  if (ptxTypeA)
+    result.addAttribute("multiplicandAPtxType",
+                        MMATypesAttr::get(ctx, *ptxTypeA));
+  if (ptxTypeB)
+    result.addAttribute("multiplicandBPtxType",
+                        MMATypesAttr::get(ctx, *ptxTypeB));
+
+  if (intOverflow)
+    result.addAttribute("intOverflowBehavior",
+                        MMAIntOverflowAttr::get(ctx, *intOverflow));
+  if (b1Op)
+    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));
+}
+
+// <operation> :=
+//   A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
+//   attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
+//     `->` type($res)
+// The multiplicand PTX type attributes may be omitted when they can be
+// inferred from the A/B operand types; otherwise an error is emitted.
+ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
+  struct OperandFragment {
+    Optional<MMATypes> elemtype;
+    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
+    SmallVector<Type> regTypes;
+  };
+
+  Builder &builder = parser.getBuilder();
+  // Fragments A, B, C (operands, indices 0-2) and D (result, index 3).
+  std::array<OperandFragment, 4> frags;
+
+  NamedAttrList namedAttributes;
+
+  // A helper to parse one `<name> [ <operands> ]` segment.
+  auto parseMmaOperand = [&](StringRef operandName,
+                             OperandFragment &frag) -> LogicalResult {
+    if (parser.parseKeyword(operandName).failed())
+      return failure();
+    if (parser
+            .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
+            .failed())
+      return failure();
+    return success();
+  };
+
+  // Parse the operand segments.
+  if (parseMmaOperand("A", frags[0]).failed())
+    return failure();
+  if (parseMmaOperand("B", frags[1]).failed())
+    return failure();
+  if (parseMmaOperand("C", frags[2]).failed())
+    return failure();
+
+  if (parser.parseOptionalAttrDict(namedAttributes).failed())
+    return failure();
+
+  // Parse the type specification and resolve operands.
+  SmallVector<Type, 3> operandTypes;
+  if (failed(parser.parseColon()))
+    return failure();
+  if (failed(parser.parseLParen()))
+    return failure();
+  if (failed(parser.parseTypeList(operandTypes)))
+    return failure();
+  if (failed(parser.parseRParen()))
+    return failure();
+  // Exactly one type per operand segment; this also keeps `frags` in bounds.
+  if (operandTypes.size() != 3)
+    return parser.emitError(
+        parser.getNameLoc(),
+        "expected one type for each operand segment but got " +
+            Twine(operandTypes.size()) + " types");
+  for (auto iter : llvm::enumerate(operandTypes)) {
+    auto &frag = frags[iter.index()];
+    frag.regTypes.resize(frag.regs.size(), iter.value());
+    if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
+                                      parser.getNameLoc(), result.operands)))
+      return failure();
+    // Only segment C (index 2) is an accumulator, matching `build` and
+    // `print`. Infer from the segment type itself so that an empty segment
+    // does not index into `regTypes`.
+    frag.elemtype =
+        inferOperandMMAType(iter.value(), /*isAccum=*/iter.index() >= 2);
+  }
+
+  Type resultType;
+  if (parser.parseArrow() || parser.parseType(resultType))
+    return failure();
+  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccum=*/true);
+
+  std::array<StringRef, 2> names{"multiplicandAPtxType",
+                                 "multiplicandBPtxType"};
+  for (unsigned idx = 0; idx < names.size(); idx++) {
+    const auto &frag = frags[idx];
+    Optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
+    if (!frag.elemtype.hasValue() && !attr.hasValue()) {
+      return parser.emitError(
+          parser.getNameLoc(),
+          "attribute " + names[idx] +
+              " is not provided explicitly and cannot be inferred");
+    }
+    if (!attr.hasValue())
+      result.addAttribute(
+          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
+  }
+
+  result.addTypes(resultType);
+  if (!namedAttributes.empty())
+    result.addAttributes(namedAttributes);
+  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
+                      builder.getI32VectorAttr({
+                          static_cast<int32_t>(frags[0].regs.size()),
+                          static_cast<int32_t>(frags[1].regs.size()),
+                          static_cast<int32_t>(frags[2].regs.size()),
+                      }));
+  return success();
+}
+
 LogicalResult MmaOp::verify() {
   MLIRContext *context = getContext();
   auto f16Ty = Float16Type::get(context);
+  auto i32Ty = IntegerType::get(context, 32);
   auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
   auto f32Ty = Float32Type::get(context);
   auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
       context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
-  auto f32x8StructTy = LLVM::LLVMStructType::getLiteral(
-      context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
-
-  auto operandTypes = getOperandTypes();
-  if (operandTypes != SmallVector<Type, 8>(8, f16x2Ty) &&
-      operandTypes != ArrayRef<Type>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f32Ty,
-                                     f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
-                                     f32Ty}) {
-    return emitOpError("expected operands to be 4 <halfx2>s followed by either "
-                       "4 <halfx2>s or 8 floats");
+
+  auto s32x4StructTy =
+      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
+  auto f32x8StructTy =
+      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
+  auto f16x2x2StructTy =
+      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
+  auto f32x4StructTy =
+      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
+  auto s32x2StructTy =
+      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
+
+  std::array<int64_t, 3> mmaShape{shapeAttr().m().getInt(),
+                                  shapeAttr().n().getInt(),
+                                  shapeAttr().k().getInt()};
+
+  // These variables define the set of allowed data types for matrices A, B, C,
+  // and result.
+  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
+  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
+  AllowedShapes allowedShapes;
+  AllowedTypes expectedA;
+  AllowedTypes expectedB;
+  AllowedTypes expectedC;
+  SmallVector<Type> expectedResult;
+
+  // When M = 16, we just need to calculate the number of 8xk tiles, where
+  // k is a factor that depends on the data type.
+  if (mmaShape[0] == 16) {
+    int64_t kFactor;
+    Type multiplicandFragType;
+    switch (multiplicandAPtxType().getValue()) {
+    case MMATypes::tf32:
+    case MMATypes::bf16:
+      // Packed into i32 registers; the result is always four f32 values.
+      kFactor = multiplicandAPtxType().getValue() == MMATypes::tf32 ? 4 : 8;
+      multiplicandFragType = i32Ty;
+      expectedResult.push_back(f32x4StructTy);
+      break;
+    case MMATypes::f16:
+      kFactor = 8;
+      multiplicandFragType = f16x2Ty;
+      expectedResult.append({f16x2x2StructTy, f32x4StructTy});
+      break;
+    case MMATypes::s4:
+    case MMATypes::u4:
+      kFactor = 32;
+      break;
+    case MMATypes::b1:
+      kFactor = 128;
+      break;
+    case MMATypes::s8:
+    case MMATypes::u8:
+      kFactor = 16;
+      break;
+    default:
+      return emitOpError("invalid shape or multiplicand type: " +
+                         stringifyEnum(multiplicandAPtxType().getValue()));
+    }
+
+    if (isIntegerPtxType(multiplicandAPtxType().getValue())) {
+      expectedResult.push_back(s32x4StructTy);
+      expectedC.emplace_back(4, i32Ty);
+      multiplicandFragType = i32Ty;
+    } else {
+      expectedC.emplace_back(2, f16x2Ty);
+      expectedC.emplace_back(4, f32Ty);
+    }
+
+    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
+    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
+    expectedA.emplace_back(unitA, multiplicandFragType);
+    expectedB.emplace_back(unitB, multiplicandFragType);
+    allowedShapes.push_back({16, 8, kFactor});
+    allowedShapes.push_back({16, 8, kFactor * 2});
   }
-  if (getType() != f32x8StructTy && getType() != f16x2x4StructTy) {
-    return emitOpError("expected result type to be a struct of either 4 "
-                       "<halfx2>s or 8 floats");
+
+  // In the M=8 case, there is only 1 possible case per data type.
+  if (mmaShape[0] == 8) {
+    if (multiplicandAPtxType().getValue() == MMATypes::f16) {
+      expectedA.emplace_back(2, f16x2Ty);
+      expectedB.emplace_back(2, f16x2Ty);
+      expectedResult.push_back(f16x2x4StructTy);
+      expectedResult.push_back(f32x8StructTy);
+      expectedC.emplace_back(4, f16x2Ty);
+      expectedC.emplace_back(8, f32Ty);
+      allowedShapes.push_back({8, 8, 4});
+    }
+    if (multiplicandAPtxType().getValue() == MMATypes::f64) {
+      Type f64Ty = Float64Type::get(context);
+      expectedA.emplace_back(1, f64Ty);
+      expectedB.emplace_back(1, f64Ty);
+      expectedC.emplace_back(2, f64Ty);
+      // The C operand is passed as two scalar f64 values, not as a vector.
+      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
+          context, SmallVector<Type>(2, f64Ty)));
+      allowedShapes.push_back({8, 8, 4});
+    }
+    if (isIntegerPtxType(multiplicandAPtxType().getValue())) {
+      expectedA.push_back({i32Ty});
+      expectedB.push_back({i32Ty});
+      expectedC.push_back({i32Ty, i32Ty});
+      expectedResult.push_back(s32x2StructTy);
+      if (isInt4PtxType(multiplicandAPtxType().getValue()))
+        allowedShapes.push_back({8, 8, 32});
+      if (isInt8PtxType(multiplicandAPtxType().getValue()))
+        allowedShapes.push_back({8, 8, 16});
+      if (multiplicandAPtxType().getValue() == MMATypes::b1)
+        allowedShapes.push_back({8, 8, 128});
+    }
   }
 
-  auto alayout = (*this)->getAttrOfType<StringAttr>("alayout");
-  auto blayout = (*this)->getAttrOfType<StringAttr>("blayout");
+  std::string errorMessage;
+  llvm::raw_string_ostream errorStream(errorMessage);
 
-  if (!(alayout && blayout) ||
-      !(alayout.getValue() == "row" || alayout.getValue() == "col") ||
-      !(blayout.getValue() == "row" || blayout.getValue() == "col")) {
-    return emitOpError("alayout and blayout attributes must be set to either "
-                       "\"row\" or \"col\"");
+  // Check that we matched an existing shape/dtype combination.
+  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
+      !llvm::any_of(allowedShapes,
+                    [&](const auto &allowed) { return allowed == mmaShape; })) {
+    errorStream << "unimplemented variant for MMA shape <";
+    llvm::interleaveComma(mmaShape, errorStream);
+    errorStream << ">";
+    return emitOpError(errorMessage);
   }
 
-  if (operandTypes == ArrayRef<Type>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f32Ty,
-                                     f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
-                                     f32Ty} &&
-      getType() == f32x8StructTy && alayout.getValue() == "row" &&
-      blayout.getValue() == "col") {
-    return success();
+  // Verify the operand types for segments of A, B, and C operands.
+  std::array<StringRef, 3> operandNames{"A", "B", "C"};
+  for (const auto &iter : llvm::enumerate(
+           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
+    auto spec = this->getODSOperandIndexAndLength(iter.index());
+    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
+                                      operand_type_begin() + spec.first +
+                                          spec.second);
+    bool match =
+        llvm::any_of(iter.value(), [&](const SmallVector<Type, 4> &typeSet) {
+          return typeSet == operandTySeg;
+        });
+
+    if (!match) {
+      errorStream << "Could not match types for the "
+                  << operandNames[iter.index()]
+                  << " operands; expected one of ";
+      for (const auto &x : iter.value()) {
+        errorStream << x.size() << "x" << x[0] << " ";
+      }
+      errorStream << "but got ";
+      llvm::interleaveComma(operandTySeg, errorStream);
+      return emitOpError(errorStream.str());
+    }
+  }
+
+  // Check the result type
+  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
+        return expectedResultType == getResult().getType();
+      })) {
+    errorStream
+        << "Could not match allowed types for the result; expected one of ";
+    llvm::interleaveComma(expectedResult, errorStream);
+    errorStream << " but got " << getResult().getType();
+    return emitOpError(errorStream.str());
+  }
+
+  // Ensure that binary MMA variants have a b1 MMA operation defined.
+  if (multiplicandAPtxType() == MMATypes::b1 && !b1Op().hasValue()) {
+    return emitOpError("requires " + b1OpAttrName().strref() + " attribute");
   }
-  return emitOpError("unimplemented mma.sync variant");
+
+  // Ensure int4/int8 MMA variants specify the accum overflow behavior
+  // attribute.
+  if (isInt4PtxType(*multiplicandAPtxType()) ||
+      isInt8PtxType(*multiplicandAPtxType())) {
+    if (!intOverflowBehavior().hasValue())
+      return emitOpError("requires " +
+                         intOverflowBehaviorAttrName().strref() + " attribute");
+  }
+
+  return success();
 }
 
 LogicalResult ShflOp::verify() {

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 6c7e3ae5712d7..cc4ae43b7ae3e 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -514,12 +514,13 @@ func @nvvm_invalid_shfl_pred_3(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i3
 
 // -----
 
-func @nvvm_invalid_mma_0(%a0 : f16, %a1 : vector<2xf16>,
+func @nvvm_invalid_mma_0(%a0 : f16, %a1 : f16,
                          %b0 : vector<2xf16>, %b1 : vector<2xf16>,
                          %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
                          %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
-  // expected-error@+1 {{expected operands to be 4 <halfx2>s followed by either 4 <halfx2>s or 8 floats}}
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (f16, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+  // expected-error@+1 {{Could not match types for the A operands; expected one of 2xvector<2xf16> but got f16, f16}}
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
+    {layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (f16, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
 }
 
@@ -529,8 +530,9 @@ func @nvvm_invalid_mma_1(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                          %b0 : vector<2xf16>, %b1 : vector<2xf16>,
                          %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
                          %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
-  // expected-error@+1 {{expected result type to be a struct of either 4 <halfx2>s or 8 floats}}
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>
+  // expected-error@+1 {{Could not match allowed types for the result; expected one of !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> but got !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>}}
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
+    {layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>
 }
 
@@ -540,8 +542,9 @@ func @nvvm_invalid_mma_2(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                          %b0 : vector<2xf16>, %b1 : vector<2xf16>,
                          %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
                          %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
-  // expected-error@+1 {{alayout and blayout attributes must be set to either "row" or "col"}}
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+  // expected-error@+1 {{op requires attribute 'layoutA'}}
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
+    {shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}}: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
 }
 
@@ -549,55 +552,23 @@ func @nvvm_invalid_mma_2(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
 
 func @nvvm_invalid_mma_3(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                          %b0 : vector<2xf16>, %b1 : vector<2xf16>,
-                         %c0 : vector<2xf16>, %c1 : vector<2xf16>,
-                         %c2 : vector<2xf16>, %c3 : vector<2xf16>) {
-  // expected-error@+1 {{unimplemented mma.sync variant}}
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3 {alayout="row", blayout="col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
-  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
-}
-
-// -----
-
-func @nvvm_invalid_mma_4(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
-                         %b0 : vector<2xf16>, %b1 : vector<2xf16>,
-                         %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
-                         %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
-  // expected-error@+1 {{unimplemented mma.sync variant}}
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-  llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+                         %c0 : vector<2xf16>, %c1 : vector<2xf16>) {
+  // expected-error@+1 {{unimplemented variant for MMA shape <8, 8, 16>}}
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] {layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 }
 
 // -----
 
-func @nvvm_invalid_mma_5(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
-                         %b0 : vector<2xf16>, %b1 : vector<2xf16>,
-                         %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
-                         %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
-  // expected-error@+1 {{unimplemented mma.sync variant}}
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
-  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
-}
-
-// -----
-
-func @nvvm_invalid_mma_6(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
-                         %b0 : vector<2xf16>, %b1 : vector<2xf16>,
-                         %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
-                         %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
-  // expected-error@+1 {{invalid kind of type specified}}
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
-  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
-}
-
-// -----
-
-func @nvvm_invalid_mma_7(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
-                         %b0 : vector<2xf16>, %b1 : vector<2xf16>,
-                         %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
-                         %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
-  // expected-error@+1 {{op requires one result}}
-  %0:2 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> (!llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>, i32)
-  llvm.return %0#0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+func @nvvm_invalid_mma_8(%a0 : i32, %a1 : i32,
+                         %b0 : i32,
+                         %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
+  // expected-error@+1 {{op requires b1Op attribute}}
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
+     shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+  llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
 }
 
 // -----

diff  --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 2b191541c3b02..c2e1db76f251c 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -66,15 +66,164 @@ func @nvvm_vote(%arg0 : i32, %arg1 : i1) -> i32 {
   llvm.return %0 : i32
 }
 
-func @nvvm_mma(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+// CHECK-LABEL: @nvvm_mma_m8n8k4_row_col_f32_f32
+func @nvvm_mma_m8n8k4_row_col_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                %b0 : vector<2xf16>, %b1 : vector<2xf16>,
-               %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
-               %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
-  // CHECK: nvvm.mma.sync {{.*}} {alayout = "row", blayout = "col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout = "row", blayout = "col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+               %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
+  // CHECK: nvvm.mma.sync
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
 }
 
+func @nvvm_mma_m8n8k4_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+                              %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+                              %c0 : vector<2xf16>, %c1 : vector<2xf16>, %c2 : vector<2xf16>, %c3 : vector<2xf16>) {
+  // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}]
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>,vector<2xf16>,vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+  llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+}
+
+func @nvvm_mma_m8n8k16_s8_s8(%a0 : i32, %b0 : i32,
+                             %c0 : i32, %c1 : i32) {
+  // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = {k = 16 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32)>
+  %0 = nvvm.mma.sync A[%a0] B[%b0] C[%c0, %c1]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>,
+     intOverflowBehavior=#nvvm.mma_int_overflow<wrapped>,
+     shape = {k = 16 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32)>
+  llvm.return %0 : !llvm.struct<(i32, i32)>
+}
+
+func @nvvm_mma_m16n8k8_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+                               %b0 : vector<2xf16>,
+                               %c0 : vector<2xf16>, %c1 : vector<2xf16>) {
+  // CHECK: nvvm.mma.sync A[%{{.*}}, %{{.*}}] B[%{{.*}}] C[%{{.*}}, %{{.*}}] {{{.*}}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     shape = {k = 8 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+}
+
+func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+                                %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+                                %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+                                %c0 : vector<2xf16>, %c1 : vector<2xf16>) {
+  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+}
+
+func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+                                %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+                                %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+                                %c0 : vector<2xf16>, %c1 : vector<2xf16>) {
+  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)>
+  %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>,vector<2xf16>,vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)>
+  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+func @nvvm_mma_m16n8k16_f16_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+                                %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+                                %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+                                %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
+  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+}
+
+func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+                                %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+                                %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+                                %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
+  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32, %b0 : i32,
+                              %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
+  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>,
+     intOverflowBehavior=#nvvm.mma_int_overflow<wrapped>,
+     shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+  llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
+}
+
+func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32,
+                                %b0 : i32,
+                                %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
+  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<u8>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<u8>,
+     intOverflowBehavior=#nvvm.mma_int_overflow<satfinite>,
+     shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+  llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
+}
+
+func @nvvm_mma_m16n8k256_b1_b1(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+                               %b0 : i32, %b1 : i32,
+                               %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
+  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op<xor_popc>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>, shape = {k = 256 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
+     b1Op = #nvvm.mma_b1op<xor_popc>, shape = {k = 256 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+  llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
+}
+
+func @nvvm_mma_m16n8k128_b1_b1(%a0 : i32, %a1 : i32,
+                               %b0 : i32,
+                               %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
+  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op<xor_popc>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>, shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
+     b1Op = #nvvm.mma_b1op<xor_popc>,
+     shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+  llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_m8n8k128_b1_b1
+func @nvvm_mma_m8n8k128_b1_b1(%a0 : i32,
+                              %b0 : i32,
+                              %c0 : i32, %c1 : i32) {
+  // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op<xor_popc>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>, shape = {k = 128 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32)>
+  %0 = nvvm.mma.sync A[%a0] B[%b0] C[%c0, %c1]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
+     b1Op = #nvvm.mma_b1op<xor_popc>, shape = {k = 128 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32)>
+  llvm.return %0 : !llvm.struct<(i32,i32)>
+}
+
+// CHECK-LABEL: @nvvm_mma_m16n8k32_s4_s4
+func @nvvm_mma_m16n8k32_s4_s4(%a0 : i32, %a1 : i32,
+                               %b0 : i32,
+                               %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
+  // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>,
+     intOverflowBehavior=#nvvm.mma_int_overflow<wrapped>,
+     shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
+  llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
+}
+
+// CHECK-LABEL: @nvvm_wmma_load_tf32
 func @nvvm_wmma_load_tf32(%arg0: !llvm.ptr<i32>, %arg1 : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
   // CHECK: nvvm.wmma.load {{.*}} {eltype = #nvvm.mma_type<tf32>, frag = #nvvm.mma_frag<a>, k = 8 : i32, layout = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32}
   %0 = nvvm.wmma.load %arg0, %arg1

diff  --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index b62913b7c2737..ad73a295359df 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -88,17 +88,124 @@ llvm.func @nvvm_vote(%0 : i32, %1 : i1) -> i32 {
   llvm.return %3 : i32
 }
 
-llvm.func @nvvm_mma(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+// CHECK-LABEL: @nvvm_mma_m8n8k4_row_col_f32_f32
+llvm.func @nvvm_mma_m8n8k4_row_col_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                     %b0 : vector<2xf16>, %b1 : vector<2xf16>,
                     %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
                     %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> {
   // CHECK: call { float, float, float, float, float, float, float, float } @llvm.nvvm.mma.m8n8k4.row.col.f32.f32
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {m = 8 : i32, n = 8 : i32, k = 4 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
 }
 
+llvm.func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+                                     %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+                                     %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+                                     %c0 : vector<2xf16>, %c1 : vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
+  // CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.m16n8k16.row.col.f16.f16
+  %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+}
+
+// f32 result type, f16 accumulator (C operand) type
+llvm.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+                                     %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+                                     %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+                                     %c0 : vector<2xf16>, %c1 : vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)> {
+  // CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.f32.f16
+  %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)>
+  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// f16 result type, f32 accumulator (C operand) type
+llvm.func @nvvm_mma_m16n8k16_f16_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+                                     %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+                                     %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+                                     %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
+  // CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.m16n8k16.row.col.f16.f32
+  %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+}
+
+// f32 result type, f32 accumulator (C operand) type
+llvm.func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+                                     %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+                                     %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+                                     %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(f32, f32, f32, f32)> {
+  // CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.f32.f32
+  %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+llvm.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32,
+                                   %b0 : i32,
+                                   %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+  // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k16.row.col.s8
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>,
+     intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+     shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  llvm.return %0 : !llvm.struct<(i32, i32, i32, i32)>
+}
+
+llvm.func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32,
+                                   %b0 : i32,
+                                   %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+  // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k16.row.col.satfinite.s8.u8
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<u8>,
+     intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>,
+     shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  llvm.return %0 : !llvm.struct<(i32, i32, i32, i32)>
+}
+
+llvm.func @nvvm_mma_m16n8k128_b1_b1(%a0 : i32, %a1 : i32,
+                                    %b0 : i32,
+                                    %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+  // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.xor.popc.m16n8k128.row.col.b1
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
+     b1Op = #nvvm.mma_b1op<xor_popc>, shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  llvm.return %0 : !llvm.struct<(i32, i32, i32, i32)>
+}
+
+llvm.func @nvvm_mma_m16n8k32_s4_s4(%a0 : i32, %a1 : i32,
+                                   %b0 : i32,
+                                   %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
+  // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k32.row.col.satfinite.s4
+  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>,
+     intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>,
+     shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  llvm.return %0 : !llvm.struct<(i32, i32, i32, i32)>
+}
+
+llvm.func @nvvm_mma_m8n8k4_f64_f64(%a0 : f64,
+                                   %b0 : f64,
+                                   %c0 : f64, %c1 : f64) -> !llvm.struct<(f64, f64)> {
+  // CHECK: call { double, double } @llvm.nvvm.mma.m8n8k4.row.col.f64
+  %0 = nvvm.mma.sync A[%a0] B[%b0] C[%c0, %c1]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     shape = {m = 8 : i32, n = 8 : i32, k = 4 : i32}} : (f64, f64, f64) -> !llvm.struct<(f64, f64)>
+  llvm.return %0 : !llvm.struct<(f64, f64)>
+}
+
 // The test below checks the correct mapping of the nvvm.wmma.*.load.* op to the correct intrinsic
 // in the LLVM NVPTX backend.
+// CHECK-LABEL: @gpu_wmma_load_op
 llvm.func @gpu_wmma_load_op(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
   // CHECK: call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p3i32(i32 addrspace(3)* %{{.*}}, i32 %{{.*}})
   %0 = nvvm.wmma.load %arg0, %arg1


        


More information about the Mlir-commits mailing list