[llvm] 3644726 - [Clang][NVPTX] Add NVPTX intrinsics and builtins for CUDA PTX 6.5 and 7.0 WMMA and MMA instructions

Artem Belevich via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 29 15:44:56 PDT 2021


Author: Steffen Larsen
Date: 2021-06-29T15:44:07-07:00
New Revision: 3644726a78e37823b1687a7aa8d186e91570ffe2

URL: https://github.com/llvm/llvm-project/commit/3644726a78e37823b1687a7aa8d186e91570ffe2
DIFF: https://github.com/llvm/llvm-project/commit/3644726a78e37823b1687a7aa8d186e91570ffe2.diff

LOG: [Clang][NVPTX] Add NVPTX intrinsics and builtins for CUDA PTX 6.5 and 7.0 WMMA and MMA instructions

Adds NVPTX builtins and intrinsics for the CUDA PTX `wmma.load`, `wmma.store`, `wmma.mma`, and `mma` instructions added in PTX 6.5 and 7.0.

PTX ISA descriptions of the instructions:

  - `wmma.load`: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-wmma-ld
  - `wmma.store`: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-wmma-st
  - `wmma.mma`: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-wmma-mma
  - `mma`: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma

Overview of `wmma.mma` and `mma` matrix shape/type combinations added with specific PTX versions: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-shape
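
For reference, a minimal sketch (not part of this patch) of how the new double-precision builtins might be used from CUDA device code compiled for sm_80 with PTX 7.0; it mirrors the generated builtins-nvptx-mma.cu test. The function and variable names are illustrative, and the per-thread fragment sizes (one double for the m8n8k4 a/b fragments, two doubles for c/d) follow the builtin handling added in CGBuiltin.cpp:

  // Sketch only: computes D = A * B + C on an 8x8x4 double-precision tile.
  // All threads of the warp must reach these warp-level matrix operations.
  __device__ void dmma_m8n8k4_sketch(const double *a, const double *b,
                                     const double *c, double *d, int ldm) {
    double fa[1], fb[1], fc[2], fd[2]; // per-thread fragments
    // The trailing argument selects the layout: 0 = row-major, 1 = column-major.
    __dmma_m8n8k4_ld_a(fa, a, ldm, /*layout=*/0);
    __dmma_m8n8k4_ld_b(fb, b, ldm, /*layout=*/0);
    __dmma_m8n8k4_ld_c(fc, c, ldm, /*layout=*/0);
    // Layout 0 selects the row.row intrinsic variant; satf must be 0 because
    // .satfinite is not supported for f64.
    __dmma_m8n8k4_mma_f64(fd, fa, fb, fc, /*layout=*/0, /*satf=*/0);
    __dmma_m8n8k4_st_c_f64(d, fd, ldm, /*layout=*/0);
  }

The bf16 and tf32 builtins follow the same pattern, except that their a/b fragments are packed into 32-bit registers, so those builtins take int pointers for the a/b operands.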

Authored-by: Steffen Larsen <steffen.larsen@codeplay.com>
Co-Authored-by: Stuart Adams <stuart.adams@codeplay.com>

Reviewed By: tra

Differential Revision: https://reviews.llvm.org/D104847

Added: 
    

Modified: 
    clang/include/clang/Basic/BuiltinsNVPTX.def
    clang/lib/CodeGen/CGBuiltin.cpp
    clang/test/CodeGen/builtins-nvptx-mma.cu
    clang/test/CodeGen/builtins-nvptx-mma.py
    llvm/include/llvm/IR/IntrinsicsNVVM.td
    llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
    llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
    llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
    llvm/test/CodeGen/NVPTX/lit.local.cfg
    llvm/test/CodeGen/NVPTX/wmma.py

Removed: 
    


################################################################################
diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.def b/clang/include/clang/Basic/BuiltinsNVPTX.def
index 98f3c659b7cec..e815138a15c15 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.def
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.def
@@ -759,6 +759,29 @@ TARGET_BUILTIN(__imma_m8n8k32_mma_s4, "vi*iC*iC*iC*IiIi", "", AND(SM_75,PTX63))
 TARGET_BUILTIN(__imma_m8n8k32_mma_u4, "vi*iC*iC*iC*IiIi", "", AND(SM_75,PTX63))
 TARGET_BUILTIN(__imma_m8n8k32_st_c_i32, "vi*iC*UiIi", "", AND(SM_75,PTX63))
 
+// Builtins to support double and alternate float WMMA instructions on sm_80
+TARGET_BUILTIN(__dmma_m8n8k4_ld_a, "vd*dC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__dmma_m8n8k4_ld_b, "vd*dC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__dmma_m8n8k4_ld_c, "vd*dC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__dmma_m8n8k4_st_c_f64, "vd*dC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__dmma_m8n8k4_mma_f64, "vd*dC*dC*dC*IiIi", "", AND(SM_80,PTX70))
+
+TARGET_BUILTIN(__mma_bf16_m16n16k16_ld_a, "vi*iC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__mma_bf16_m16n16k16_ld_b, "vi*iC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__mma_bf16_m16n16k16_mma_f32, "vf*iC*iC*fC*IiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__mma_bf16_m8n32k16_ld_a, "vi*iC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__mma_bf16_m8n32k16_ld_b, "vi*iC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__mma_bf16_m8n32k16_mma_f32, "vf*iC*iC*fC*IiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__mma_bf16_m32n8k16_ld_a, "vi*iC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__mma_bf16_m32n8k16_ld_b, "vi*iC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__mma_bf16_m32n8k16_mma_f32, "vf*iC*iC*fC*IiIi", "", AND(SM_80,PTX70))
+
+TARGET_BUILTIN(__mma_tf32_m16n16k8_ld_a, "vi*iC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__mma_tf32_m16n16k8_ld_b, "vi*iC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__mma_tf32_m16n16k8_ld_c, "vf*fC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__mma_m16n16k8_st_c_f32, "vf*fC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__mma_tf32_m16n16k8_mma_f32, "vf*iC*iC*fC*IiIi", "", AND(SM_80,PTX70))
+
 // Async Copy
 TARGET_BUILTIN(__nvvm_cp_async_mbarrier_arrive, "vWi*", "", AND(SM_80,PTX70))
 TARGET_BUILTIN(__nvvm_cp_async_mbarrier_arrive_shared, "vWi*3", "", AND(SM_80,PTX70))

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 2f2d5e6c83d77..3fc9ba414397e 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -16402,6 +16402,34 @@ static NVPTXMmaLdstInfo getNVPTXMmaLdstInfo(unsigned BuiltinID) {
   case NVPTX::BI__bmma_m8n8k128_ld_c:
     return MMA_LDST(2, m8n8k128_load_c_s32);
 
+  // Double MMA loads
+  case NVPTX::BI__dmma_m8n8k4_ld_a:
+    return MMA_LDST(1, m8n8k4_load_a_f64);
+  case NVPTX::BI__dmma_m8n8k4_ld_b:
+    return MMA_LDST(1, m8n8k4_load_b_f64);
+  case NVPTX::BI__dmma_m8n8k4_ld_c:
+    return MMA_LDST(2, m8n8k4_load_c_f64);
+
+  // Alternate float MMA loads
+  case NVPTX::BI__mma_bf16_m16n16k16_ld_a:
+    return MMA_LDST(4, m16n16k16_load_a_bf16);
+  case NVPTX::BI__mma_bf16_m16n16k16_ld_b:
+    return MMA_LDST(4, m16n16k16_load_b_bf16);
+  case NVPTX::BI__mma_bf16_m8n32k16_ld_a:
+    return MMA_LDST(2, m8n32k16_load_a_bf16);
+  case NVPTX::BI__mma_bf16_m8n32k16_ld_b:
+    return MMA_LDST(8, m8n32k16_load_b_bf16);
+  case NVPTX::BI__mma_bf16_m32n8k16_ld_a:
+    return MMA_LDST(8, m32n8k16_load_a_bf16);
+  case NVPTX::BI__mma_bf16_m32n8k16_ld_b:
+    return MMA_LDST(2, m32n8k16_load_b_bf16);
+  case NVPTX::BI__mma_tf32_m16n16k8_ld_a:
+    return MMA_LDST(4, m16n16k8_load_a_tf32);
+  case NVPTX::BI__mma_tf32_m16n16k8_ld_b:
+    return MMA_LDST(2, m16n16k8_load_b_tf32);
+  case NVPTX::BI__mma_tf32_m16n16k8_ld_c:
+    return MMA_LDST(8, m16n16k8_load_c_f32);
+
   // NOTE: We need to follow inconsitent naming scheme used by NVCC.  Unlike
   // PTX and LLVM IR where stores always use fragment D, NVCC builtins always
   // use fragment C for both loads and stores.
@@ -16433,6 +16461,14 @@ static NVPTXMmaLdstInfo getNVPTXMmaLdstInfo(unsigned BuiltinID) {
   case NVPTX::BI__bmma_m8n8k128_st_c_i32:
     return MMA_LDST(2, m8n8k128_store_d_s32);
 
+  // Double MMA store
+  case NVPTX::BI__dmma_m8n8k4_st_c_f64:
+    return MMA_LDST(2, m8n8k4_store_d_f64);
+
+  // Alternate float MMA store
+  case NVPTX::BI__mma_m16n16k8_st_c_f32:
+    return MMA_LDST(8, m16n16k8_store_d_f32);
+
   default:
     llvm_unreachable("Unknown MMA builtin");
   }
@@ -16446,10 +16482,14 @@ struct NVPTXMmaInfo {
   unsigned NumEltsB;
   unsigned NumEltsC;
   unsigned NumEltsD;
+
+  // Variants are ordered by layout-A/layout-B/satf, where 'row' has priority
+  // over 'col' for layout. The index of non-satf variants is expected to match
+  // the undocumented layout constants used by CUDA's mma.hpp.
   std::array<unsigned, 8> Variants;
 
   unsigned getMMAIntrinsic(int Layout, bool Satf) {
-    unsigned Index = Layout * 2 + Satf;
+    unsigned Index = Layout + 4 * Satf;
     if (Index >= Variants.size())
       return 0;
     return Variants[Index];
@@ -16460,93 +16500,107 @@ struct NVPTXMmaInfo {
   // Layout and Satf, 0 otherwise.
 static NVPTXMmaInfo getNVPTXMmaInfo(unsigned BuiltinID) {
   // clang-format off
-#define MMA_VARIANTS(geom, type) {{                                 \
+#define MMA_VARIANTS(geom, type)                                    \
       Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type,             \
-      Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type##_satfinite, \
       Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type,             \
-      Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type##_satfinite, \
       Intrinsic::nvvm_wmma_##geom##_mma_col_row_##type,             \
+      Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type
+#define MMA_SATF_VARIANTS(geom, type)                               \
+      MMA_VARIANTS(geom, type),                                     \
+      Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type##_satfinite, \
+      Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type##_satfinite, \
       Intrinsic::nvvm_wmma_##geom##_mma_col_row_##type##_satfinite, \
-      Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type,             \
-      Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type##_satfinite  \
-    }}
+      Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type##_satfinite
 // Sub-integer MMA only supports row.col layout.
-#define MMA_VARIANTS_I4(geom, type) {{ \
-      0, \
+#define MMA_VARIANTS_I4(geom, type) \
       0, \
       Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type,             \
-      Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type##_satfinite, \
       0, \
       0, \
       0, \
-      0  \
-    }}
-// b1 MMA does not support .satfinite.
-#define MMA_VARIANTS_B1(geom, type) {{ \
+      Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type##_satfinite, \
       0, \
+      0
+// b1 MMA does not support .satfinite.
+#define MMA_VARIANTS_B1(geom, type) \
       0, \
       Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type,             \
       0, \
       0, \
       0, \
       0, \
-      0  \
-    }}
-    // clang-format on
-    switch (BuiltinID) {
-    // FP MMA
-    // Note that 'type' argument of MMA_VARIANT uses D_C notation, while
-    // NumEltsN of return value are ordered as A,B,C,D.
-    case NVPTX::BI__hmma_m16n16k16_mma_f16f16:
-      return {8, 8, 4, 4, MMA_VARIANTS(m16n16k16, f16_f16)};
-    case NVPTX::BI__hmma_m16n16k16_mma_f32f16:
-      return {8, 8, 4, 8, MMA_VARIANTS(m16n16k16, f32_f16)};
-    case NVPTX::BI__hmma_m16n16k16_mma_f16f32:
-      return {8, 8, 8, 4, MMA_VARIANTS(m16n16k16, f16_f32)};
-    case NVPTX::BI__hmma_m16n16k16_mma_f32f32:
-      return {8, 8, 8, 8, MMA_VARIANTS(m16n16k16, f32_f32)};
-    case NVPTX::BI__hmma_m32n8k16_mma_f16f16:
-      return {8, 8, 4, 4, MMA_VARIANTS(m32n8k16, f16_f16)};
-    case NVPTX::BI__hmma_m32n8k16_mma_f32f16:
-      return {8, 8, 4, 8, MMA_VARIANTS(m32n8k16, f32_f16)};
-    case NVPTX::BI__hmma_m32n8k16_mma_f16f32:
-      return {8, 8, 8, 4, MMA_VARIANTS(m32n8k16, f16_f32)};
-    case NVPTX::BI__hmma_m32n8k16_mma_f32f32:
-      return {8, 8, 8, 8, MMA_VARIANTS(m32n8k16, f32_f32)};
-    case NVPTX::BI__hmma_m8n32k16_mma_f16f16:
-      return {8, 8, 4, 4, MMA_VARIANTS(m8n32k16, f16_f16)};
-    case NVPTX::BI__hmma_m8n32k16_mma_f32f16:
-      return {8, 8, 4, 8, MMA_VARIANTS(m8n32k16, f32_f16)};
-    case NVPTX::BI__hmma_m8n32k16_mma_f16f32:
-      return {8, 8, 8, 4, MMA_VARIANTS(m8n32k16, f16_f32)};
-    case NVPTX::BI__hmma_m8n32k16_mma_f32f32:
-      return {8, 8, 8, 8, MMA_VARIANTS(m8n32k16, f32_f32)};
-
-    // Integer MMA
-    case NVPTX::BI__imma_m16n16k16_mma_s8:
-      return {2, 2, 8, 8, MMA_VARIANTS(m16n16k16, s8)};
-    case NVPTX::BI__imma_m16n16k16_mma_u8:
-      return {2, 2, 8, 8, MMA_VARIANTS(m16n16k16, u8)};
-    case NVPTX::BI__imma_m32n8k16_mma_s8:
-      return {4, 1, 8, 8, MMA_VARIANTS(m32n8k16, s8)};
-    case NVPTX::BI__imma_m32n8k16_mma_u8:
-      return {4, 1, 8, 8, MMA_VARIANTS(m32n8k16, u8)};
-    case NVPTX::BI__imma_m8n32k16_mma_s8:
-      return {1, 4, 8, 8, MMA_VARIANTS(m8n32k16, s8)};
-    case NVPTX::BI__imma_m8n32k16_mma_u8:
-      return {1, 4, 8, 8, MMA_VARIANTS(m8n32k16, u8)};
-
-    // Sub-integer MMA
-    case NVPTX::BI__imma_m8n8k32_mma_s4:
-      return {1, 1, 2, 2, MMA_VARIANTS_I4(m8n8k32, s4)};
-    case NVPTX::BI__imma_m8n8k32_mma_u4:
-      return {1, 1, 2, 2, MMA_VARIANTS_I4(m8n8k32, u4)};
-    case NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1:
-      return {1, 1, 2, 2, MMA_VARIANTS_B1(m8n8k128, b1)};
-    default:
-      llvm_unreachable("Unexpected builtin ID.");
-    }
+      0, \
+      0
+  // clang-format on
+  switch (BuiltinID) {
+  // FP MMA
+  // Note that 'type' argument of MMA_SATF_VARIANTS uses D_C notation, while
+  // NumEltsN of return value are ordered as A,B,C,D.
+  case NVPTX::BI__hmma_m16n16k16_mma_f16f16:
+    return {8, 8, 4, 4, {{MMA_SATF_VARIANTS(m16n16k16, f16_f16)}}};
+  case NVPTX::BI__hmma_m16n16k16_mma_f32f16:
+    return {8, 8, 4, 8, {{MMA_SATF_VARIANTS(m16n16k16, f32_f16)}}};
+  case NVPTX::BI__hmma_m16n16k16_mma_f16f32:
+    return {8, 8, 8, 4, {{MMA_SATF_VARIANTS(m16n16k16, f16_f32)}}};
+  case NVPTX::BI__hmma_m16n16k16_mma_f32f32:
+    return {8, 8, 8, 8, {{MMA_SATF_VARIANTS(m16n16k16, f32_f32)}}};
+  case NVPTX::BI__hmma_m32n8k16_mma_f16f16:
+    return {8, 8, 4, 4, {{MMA_SATF_VARIANTS(m32n8k16, f16_f16)}}};
+  case NVPTX::BI__hmma_m32n8k16_mma_f32f16:
+    return {8, 8, 4, 8, {{MMA_SATF_VARIANTS(m32n8k16, f32_f16)}}};
+  case NVPTX::BI__hmma_m32n8k16_mma_f16f32:
+    return {8, 8, 8, 4, {{MMA_SATF_VARIANTS(m32n8k16, f16_f32)}}};
+  case NVPTX::BI__hmma_m32n8k16_mma_f32f32:
+    return {8, 8, 8, 8, {{MMA_SATF_VARIANTS(m32n8k16, f32_f32)}}};
+  case NVPTX::BI__hmma_m8n32k16_mma_f16f16:
+    return {8, 8, 4, 4, {{MMA_SATF_VARIANTS(m8n32k16, f16_f16)}}};
+  case NVPTX::BI__hmma_m8n32k16_mma_f32f16:
+    return {8, 8, 4, 8, {{MMA_SATF_VARIANTS(m8n32k16, f32_f16)}}};
+  case NVPTX::BI__hmma_m8n32k16_mma_f16f32:
+    return {8, 8, 8, 4, {{MMA_SATF_VARIANTS(m8n32k16, f16_f32)}}};
+  case NVPTX::BI__hmma_m8n32k16_mma_f32f32:
+    return {8, 8, 8, 8, {{MMA_SATF_VARIANTS(m8n32k16, f32_f32)}}};
+
+  // Integer MMA
+  case NVPTX::BI__imma_m16n16k16_mma_s8:
+    return {2, 2, 8, 8, {{MMA_SATF_VARIANTS(m16n16k16, s8)}}};
+  case NVPTX::BI__imma_m16n16k16_mma_u8:
+    return {2, 2, 8, 8, {{MMA_SATF_VARIANTS(m16n16k16, u8)}}};
+  case NVPTX::BI__imma_m32n8k16_mma_s8:
+    return {4, 1, 8, 8, {{MMA_SATF_VARIANTS(m32n8k16, s8)}}};
+  case NVPTX::BI__imma_m32n8k16_mma_u8:
+    return {4, 1, 8, 8, {{MMA_SATF_VARIANTS(m32n8k16, u8)}}};
+  case NVPTX::BI__imma_m8n32k16_mma_s8:
+    return {1, 4, 8, 8, {{MMA_SATF_VARIANTS(m8n32k16, s8)}}};
+  case NVPTX::BI__imma_m8n32k16_mma_u8:
+    return {1, 4, 8, 8, {{MMA_SATF_VARIANTS(m8n32k16, u8)}}};
+
+  // Sub-integer MMA
+  case NVPTX::BI__imma_m8n8k32_mma_s4:
+    return {1, 1, 2, 2, {{MMA_VARIANTS_I4(m8n8k32, s4)}}};
+  case NVPTX::BI__imma_m8n8k32_mma_u4:
+    return {1, 1, 2, 2, {{MMA_VARIANTS_I4(m8n8k32, u4)}}};
+  case NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1:
+    return {1, 1, 2, 2, {{MMA_VARIANTS_B1(m8n8k128, b1)}}};
+
+  // Double MMA
+  case NVPTX::BI__dmma_m8n8k4_mma_f64:
+    return {1, 1, 2, 2, {{MMA_VARIANTS(m8n8k4, f64)}}};
+
+  // Alternate FP MMA
+  case NVPTX::BI__mma_bf16_m16n16k16_mma_f32:
+    return {4, 4, 8, 8, {{MMA_VARIANTS(m16n16k16, bf16)}}};
+  case NVPTX::BI__mma_bf16_m8n32k16_mma_f32:
+    return {2, 8, 8, 8, {{MMA_VARIANTS(m8n32k16, bf16)}}};
+  case NVPTX::BI__mma_bf16_m32n8k16_mma_f32:
+    return {8, 2, 8, 8, {{MMA_VARIANTS(m32n8k16, bf16)}}};
+  case NVPTX::BI__mma_tf32_m16n16k8_mma_f32:
+    return {4, 4, 8, 8, {{MMA_VARIANTS(m16n16k8, tf32)}}};
+  default:
+    llvm_unreachable("Unexpected builtin ID.");
+  }
 #undef MMA_VARIANTS
+#undef MMA_SATF_VARIANTS
 #undef MMA_VARIANTS_I4
 #undef MMA_VARIANTS_B1
 }
@@ -16844,7 +16898,20 @@ CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, const CallExpr *E) {
   case NVPTX::BI__bmma_m8n8k128_ld_a_b1:
   case NVPTX::BI__bmma_m8n8k128_ld_b_b1:
   case NVPTX::BI__bmma_m8n8k128_ld_c:
-  {
+  // Double MMA loads.
+  case NVPTX::BI__dmma_m8n8k4_ld_a:
+  case NVPTX::BI__dmma_m8n8k4_ld_b:
+  case NVPTX::BI__dmma_m8n8k4_ld_c:
+  // Alternate float MMA loads.
+  case NVPTX::BI__mma_bf16_m16n16k16_ld_a:
+  case NVPTX::BI__mma_bf16_m16n16k16_ld_b:
+  case NVPTX::BI__mma_bf16_m8n32k16_ld_a:
+  case NVPTX::BI__mma_bf16_m8n32k16_ld_b:
+  case NVPTX::BI__mma_bf16_m32n8k16_ld_a:
+  case NVPTX::BI__mma_bf16_m32n8k16_ld_b:
+  case NVPTX::BI__mma_tf32_m16n16k8_ld_a:
+  case NVPTX::BI__mma_tf32_m16n16k8_ld_b:
+  case NVPTX::BI__mma_tf32_m16n16k8_ld_c: {
     Address Dst = EmitPointerWithAlignment(E->getArg(0));
     Value *Src = EmitScalarExpr(E->getArg(1));
     Value *Ldm = EmitScalarExpr(E->getArg(2));
@@ -16889,7 +16956,9 @@ CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, const CallExpr *E) {
   case NVPTX::BI__imma_m32n8k16_st_c_i32:
   case NVPTX::BI__imma_m8n32k16_st_c_i32:
   case NVPTX::BI__imma_m8n8k32_st_c_i32:
-  case NVPTX::BI__bmma_m8n8k128_st_c_i32: {
+  case NVPTX::BI__bmma_m8n8k128_st_c_i32:
+  case NVPTX::BI__dmma_m8n8k4_st_c_f64:
+  case NVPTX::BI__mma_m16n16k8_st_c_f32: {
     Value *Dst = EmitScalarExpr(E->getArg(0));
     Address Src = EmitPointerWithAlignment(E->getArg(1));
     Value *Ldm = EmitScalarExpr(E->getArg(2));
@@ -16941,7 +17010,12 @@ CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, const CallExpr *E) {
   case NVPTX::BI__imma_m8n32k16_mma_u8:
   case NVPTX::BI__imma_m8n8k32_mma_s4:
   case NVPTX::BI__imma_m8n8k32_mma_u4:
-  case NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1: {
+  case NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1:
+  case NVPTX::BI__dmma_m8n8k4_mma_f64:
+  case NVPTX::BI__mma_bf16_m16n16k16_mma_f32:
+  case NVPTX::BI__mma_bf16_m8n32k16_mma_f32:
+  case NVPTX::BI__mma_bf16_m32n8k16_mma_f32:
+  case NVPTX::BI__mma_tf32_m16n16k8_mma_f32: {
     Address Dst = EmitPointerWithAlignment(E->getArg(0));
     Address SrcA = EmitPointerWithAlignment(E->getArg(1));
     Address SrcB = EmitPointerWithAlignment(E->getArg(2));

diff --git a/clang/test/CodeGen/builtins-nvptx-mma.cu b/clang/test/CodeGen/builtins-nvptx-mma.cu
index cc31f6f4779a5..7e9bac86792d2 100644
--- a/clang/test/CodeGen/builtins-nvptx-mma.cu
+++ b/clang/test/CodeGen/builtins-nvptx-mma.cu
@@ -3,21 +3,20 @@
 // *** DO NOT EDIT ***
 //
 //  This test has been automatically generated by
-//  builtins-nvtx-mma.py --ptx=63 --gpu-arch=75
+//  builtins-nvtx-mma.py --ptx=70 --gpu-arch=80
 //
-// Make sure we can handle all builtins available on sm_75 with PTX63
-// RUN: %clang_cc1 -triple nvptx64-unknown-unknown -target-cpu sm_75 \
-// RUN:            -fcuda-is-device -target-feature +ptx63 \
-// RUN:            -DPTX=63 -DSM=75 \
+// Make sure we can handle all builtins available on sm_80 with PTX70
+// RUN: %clang_cc1 -triple nvptx64-unknown-unknown -target-cpu sm_80 \
+// RUN:            -fcuda-is-device -target-feature +ptx70 \
+// RUN:            -DPTX=70 -DSM=80 \
 // RUN:            -S -emit-llvm -o - -x cuda %s \
-// RUN:   | FileCheck -check-prefixes=CHECK_PTX61_SM70,CHECK_PTX63_SM75,CHECK_PTX63_SM72,CHECK_PTX60_SM70 %s
+// RUN:   | FileCheck -check-prefixes=CHECK_PTX70_SM80,CHECK_PTX60_SM70,CHECK_PTX63_SM72,CHECK_PTX61_SM70,CHECK_PTX63_SM75 %s
 // Verify that all builtins have correct constraints.
 // RUN: %clang_cc1 -triple nvptx-unknown-unknown \
 // RUN:   -target-cpu sm_60 -target-feature +ptx42 \
-// RUN:   -DPTX=63 -DSM=75 -fcuda-is-device -S -o /dev/null -x cuda \
+// RUN:   -DPTX=70 -DSM=80 -fcuda-is-device -S -o /dev/null -x cuda \
 // RUN:   -verify %s
 
-
 #if !defined(CUDA_VERSION)
 #define __device__ __attribute__((device))
 #define __global__ __attribute__((global))
@@ -29,8 +28,8 @@ typedef unsigned long long uint64_t;
 
 // CHECK-LABEL: test_wmma_buitins
 __device__ void test_wmma_buitins(int *src, int *dst,
-                                  float *fsrc, float *fdst, int ldm) {
-
+                                  float *fsrc, float *fdst,
+                                  double *dsrc, double *ddst, int ldm) {
 
 #if (PTX >= 60) && (SM >= 70)
 
@@ -751,5 +750,153 @@ __device__ void test_wmma_buitins(int *src, int *dst,
   // CHECK_PTX63_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k32.mma.row.col.u4.satfinite
   // expected-error-re@+1 {{'__imma_m8n8k32_mma_u4' needs target feature (sm_75{{.*}},(ptx63{{.*}}}}
   __imma_m8n8k32_mma_u4(dst, src, src, src, 1, 1);
-#endif // (PTX >= 63) && (SM >= 75) 
+#endif // (PTX >= 63) && (SM >= 75)
+
+#if (PTX >= 70) && (SM >= 80)
+
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.bf16
+  // expected-error-re@+1 {{'__mma_bf16_m16n16k16_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_bf16_m16n16k16_ld_a(dst, src, ldm, 1);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.bf16
+  // expected-error-re@+1 {{'__mma_bf16_m16n16k16_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_bf16_m16n16k16_ld_a(dst, src, ldm, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.b.col.stride.bf16
+  // expected-error-re@+1 {{'__mma_bf16_m16n16k16_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_bf16_m16n16k16_ld_b(dst, src, ldm, 1);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.b.row.stride.bf16
+  // expected-error-re@+1 {{'__mma_bf16_m16n16k16_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_bf16_m16n16k16_ld_b(dst, src, ldm, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.load.a.col.stride.tf32
+  // expected-error-re@+1 {{'__mma_tf32_m16n16k8_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_tf32_m16n16k8_ld_a(dst, src, ldm, 1);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.load.a.row.stride.tf32
+  // expected-error-re@+1 {{'__mma_tf32_m16n16k8_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_tf32_m16n16k8_ld_a(dst, src, ldm, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.load.b.col.stride.tf32
+  // expected-error-re@+1 {{'__mma_tf32_m16n16k8_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_tf32_m16n16k8_ld_b(dst, src, ldm, 1);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.load.b.row.stride.tf32
+  // expected-error-re@+1 {{'__mma_tf32_m16n16k8_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_tf32_m16n16k8_ld_b(dst, src, ldm, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.load.c.col.stride.f32
+  // expected-error-re@+1 {{'__mma_tf32_m16n16k8_ld_c' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_tf32_m16n16k8_ld_c(fdst, fsrc, ldm, 1);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.load.c.row.stride.f32
+  // expected-error-re@+1 {{'__mma_tf32_m16n16k8_ld_c' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_tf32_m16n16k8_ld_c(fdst, fsrc, ldm, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.store.d.col.stride.f32
+  // expected-error-re@+1 {{'__mma_m16n16k8_st_c_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_m16n16k8_st_c_f32(fdst, fsrc, ldm, 1);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.store.d.row.stride.f32
+  // expected-error-re@+1 {{'__mma_m16n16k8_st_c_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_m16n16k8_st_c_f32(fdst, fsrc, ldm, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.bf16
+  // expected-error-re@+1 {{'__mma_bf16_m32n8k16_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_bf16_m32n8k16_ld_a(dst, src, ldm, 1);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.bf16
+  // expected-error-re@+1 {{'__mma_bf16_m32n8k16_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_bf16_m32n8k16_ld_a(dst, src, ldm, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m32n8k16.load.b.col.stride.bf16
+  // expected-error-re@+1 {{'__mma_bf16_m32n8k16_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_bf16_m32n8k16_ld_b(dst, src, ldm, 1);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m32n8k16.load.b.row.stride.bf16
+  // expected-error-re@+1 {{'__mma_bf16_m32n8k16_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_bf16_m32n8k16_ld_b(dst, src, ldm, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.bf16
+  // expected-error-re@+1 {{'__mma_bf16_m8n32k16_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_bf16_m8n32k16_ld_a(dst, src, ldm, 1);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.bf16
+  // expected-error-re@+1 {{'__mma_bf16_m8n32k16_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_bf16_m8n32k16_ld_a(dst, src, ldm, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n32k16.load.b.col.stride.bf16
+  // expected-error-re@+1 {{'__mma_bf16_m8n32k16_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_bf16_m8n32k16_ld_b(dst, src, ldm, 1);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n32k16.load.b.row.stride.bf16
+  // expected-error-re@+1 {{'__mma_bf16_m8n32k16_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_bf16_m8n32k16_ld_b(dst, src, ldm, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.load.a.col.stride.f64
+  // expected-error-re@+1 {{'__dmma_m8n8k4_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __dmma_m8n8k4_ld_a(ddst, dsrc, ldm, 1);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.load.a.row.stride.f64
+  // expected-error-re@+1 {{'__dmma_m8n8k4_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __dmma_m8n8k4_ld_a(ddst, dsrc, ldm, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.load.b.col.stride.f64
+  // expected-error-re@+1 {{'__dmma_m8n8k4_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __dmma_m8n8k4_ld_b(ddst, dsrc, ldm, 1);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.load.b.row.stride.f64
+  // expected-error-re@+1 {{'__dmma_m8n8k4_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __dmma_m8n8k4_ld_b(ddst, dsrc, ldm, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.load.c.col.stride.f64
+  // expected-error-re@+1 {{'__dmma_m8n8k4_ld_c' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __dmma_m8n8k4_ld_c(ddst, dsrc, ldm, 1);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.load.c.row.stride.f64
+  // expected-error-re@+1 {{'__dmma_m8n8k4_ld_c' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __dmma_m8n8k4_ld_c(ddst, dsrc, ldm, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64
+  // expected-error-re@+1 {{'__dmma_m8n8k4_st_c_f64' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __dmma_m8n8k4_st_c_f64(ddst, dsrc, ldm, 1);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64
+  // expected-error-re@+1 {{'__dmma_m8n8k4_st_c_f64' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __dmma_m8n8k4_st_c_f64(ddst, dsrc, ldm, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.col.col.bf16
+  // expected-error-re@+1 {{'__mma_bf16_m16n16k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_bf16_m16n16k16_mma_f32(fdst, src, src, fsrc, 3, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.col.row.bf16
+  // expected-error-re@+1 {{'__mma_bf16_m16n16k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_bf16_m16n16k16_mma_f32(fdst, src, src, fsrc, 2, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.row.col.bf16
+  // expected-error-re@+1 {{'__mma_bf16_m16n16k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_bf16_m16n16k16_mma_f32(fdst, src, src, fsrc, 1, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.row.row.bf16
+  // expected-error-re@+1 {{'__mma_bf16_m16n16k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_bf16_m16n16k16_mma_f32(fdst, src, src, fsrc, 0, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.mma.col.col.tf32
+  // expected-error-re@+1 {{'__mma_tf32_m16n16k8_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_tf32_m16n16k8_mma_f32(fdst, src, src, fsrc, 3, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.mma.col.row.tf32
+  // expected-error-re@+1 {{'__mma_tf32_m16n16k8_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_tf32_m16n16k8_mma_f32(fdst, src, src, fsrc, 2, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.mma.row.col.tf32
+  // expected-error-re@+1 {{'__mma_tf32_m16n16k8_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_tf32_m16n16k8_mma_f32(fdst, src, src, fsrc, 1, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.mma.row.row.tf32
+  // expected-error-re@+1 {{'__mma_tf32_m16n16k8_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_tf32_m16n16k8_mma_f32(fdst, src, src, fsrc, 0, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m32n8k16.mma.col.col.bf16
+  // expected-error-re@+1 {{'__mma_bf16_m32n8k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_bf16_m32n8k16_mma_f32(fdst, src, src, fsrc, 3, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m32n8k16.mma.col.row.bf16
+  // expected-error-re@+1 {{'__mma_bf16_m32n8k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_bf16_m32n8k16_mma_f32(fdst, src, src, fsrc, 2, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m32n8k16.mma.row.col.bf16
+  // expected-error-re@+1 {{'__mma_bf16_m32n8k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_bf16_m32n8k16_mma_f32(fdst, src, src, fsrc, 1, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m32n8k16.mma.row.row.bf16
+  // expected-error-re@+1 {{'__mma_bf16_m32n8k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_bf16_m32n8k16_mma_f32(fdst, src, src, fsrc, 0, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n32k16.mma.col.col.bf16
+  // expected-error-re@+1 {{'__mma_bf16_m8n32k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_bf16_m8n32k16_mma_f32(fdst, src, src, fsrc, 3, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n32k16.mma.col.row.bf16
+  // expected-error-re@+1 {{'__mma_bf16_m8n32k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_bf16_m8n32k16_mma_f32(fdst, src, src, fsrc, 2, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n32k16.mma.row.col.bf16
+  // expected-error-re@+1 {{'__mma_bf16_m8n32k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_bf16_m8n32k16_mma_f32(fdst, src, src, fsrc, 1, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n32k16.mma.row.row.bf16
+  // expected-error-re@+1 {{'__mma_bf16_m8n32k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __mma_bf16_m8n32k16_mma_f32(fdst, src, src, fsrc, 0, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64
+  // expected-error-re@+1 {{'__dmma_m8n8k4_mma_f64' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __dmma_m8n8k4_mma_f64(ddst, dsrc, dsrc, dsrc, 3, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.mma.col.row.f64
+  // expected-error-re@+1 {{'__dmma_m8n8k4_mma_f64' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __dmma_m8n8k4_mma_f64(ddst, dsrc, dsrc, dsrc, 2, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.mma.row.col.f64
+  // expected-error-re@+1 {{'__dmma_m8n8k4_mma_f64' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __dmma_m8n8k4_mma_f64(ddst, dsrc, dsrc, dsrc, 1, 0);
+  // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64
+  // expected-error-re@+1 {{'__dmma_m8n8k4_mma_f64' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+  __dmma_m8n8k4_mma_f64(ddst, dsrc, dsrc, dsrc, 0, 0);
+#endif // (PTX >= 70) && (SM >= 80)
 }

diff --git a/clang/test/CodeGen/builtins-nvptx-mma.py b/clang/test/CodeGen/builtins-nvptx-mma.py
index 1b395fc4f33b1..2ffc21b12fb06 100644
--- a/clang/test/CodeGen/builtins-nvptx-mma.py
+++ b/clang/test/CodeGen/builtins-nvptx-mma.py
@@ -47,7 +47,13 @@ def make_ldst_ops(geoms, frags, types):
           in product(geoms, frags, types)]
 
 def get_mma_ops():
-  return (make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
+  return (make_mma_ops(["m16n16k8"],
+                       ["tf32"], [], ["f32"], []) +
+          make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
+                       ["bf16"], [], ["f32"], []) +
+          make_mma_ops(["m8n8k4"],
+                       ["f64"], [], ["f64"], []) +
+          make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                        ["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
           make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                        ["s8", "u8"], [], ["s32"], []) +
@@ -55,14 +61,18 @@ def get_mma_ops():
                        ["s4", "u4"], [], ["s32"], []) +
           make_mma_ops(["m8n8k128"],
                        ["b1"], [], ["s32"], []))
+
 def get_ldst_ops():
   return (make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
-                        ["a", "b"], ["f16", "u8", "s8"]) +
+                        ["a", "b"], ["f16", "u8", "s8", "bf16"]) +
           make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                         ["c", "d"], ["f16", "f32", "s32"]) +
           make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4","u4"]) +
           make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"]) +
-          make_ldst_ops(["m8n8k32", "m8n8k128"],  ["c", "d"], ["s32"]))
+          make_ldst_ops(["m8n8k32", "m8n8k128"],  ["c", "d"], ["s32"]) +
+          make_ldst_ops(["m8n8k4"], ["a", "b", "c", "d"], ["f64"]) +
+          make_ldst_ops(["m16n16k8"], ["a", "b"], ["tf32"]) +
+          make_ldst_ops(["m16n16k8"], ["c", "d"], ["f32"]))
 
 def is_geom_supported(geom):
   # geometries for FP and ints.
@@ -73,6 +83,8 @@ def is_geom_supported(geom):
     return ptx_version >= 63 and gpu_arch >= 75
   if geom == "m16n16k16":
     return ptx_version >= 60
+  if geom in ["m16n16k8", "m8n8k4"]:
+    return ptx_version >= 70 and gpu_arch >= 80
   assert(False) # Unexpected geometry.
 
 def is_type_supported(ptx_type):
@@ -80,16 +92,24 @@ def is_type_supported(ptx_type):
     return ptx_version >= 63 and gpu_arch >= 72
   if ptx_type in ["s4", "u4", "b1"]:
     return ptx_version >= 63 and gpu_arch >= 75
+  if ptx_type in ["bf16", "tf32", "f64"]:
+    return ptx_version >= 70 and gpu_arch >= 80
   return ptx_version >= 60 and gpu_arch >= 70
 
+def is_rnd_supported(op):
+  # rnd is only supported for FP64 WMMA
+  return op.a.ptx_type == "f64"
+
 def is_mma_variant_supported(op, layout_a, layout_b, satf):
   if not (is_type_supported(op.a.ptx_type)
           and is_geom_supported(op.a.geom)):
     return False
-  # sub-integer require row/col layout, and no satf.
+
+  if satf and not op.a.ptx_type in ["f16", "s8", "u8", "s4", "u4"]:
+    return False
+
+  # sub-integer types require row/col layout.
   if op.a.ptx_type in ["s4", "u4", "b1"]:
-    if op.a.ptx_type == "b1" and satf:
-      return False
     return layout_a == "row" and layout_b == "col"
   return True
 
@@ -98,7 +118,7 @@ def is_ldst_variant_supported(frag, layout):
           and is_geom_supported(frag.geom)):
     return False
   if frag.ptx_type in ["s4", "u4", "b1"]:
-    # sub-integer require sm_75 and ptx63, row/col layout for a/b.
+    # sub-integer types require sm_75 and ptx63, row/col layout for a/b.
     return ((frag.frag == "a" and layout == "row")
             or (frag.frag == "b" and layout == "col")
             or frag.frag in ["c", "d"])
@@ -109,12 +129,21 @@ def get_builtin_prefix(frag):
   if frag.geom in ["m16n16k16", "m32n8k16", "m8n32k16"]:
     if frag.ptx_type in ["f16", "f32"]:
       prefix = "__hmma"
+    elif frag.ptx_type == "bf16":
+      prefix = "__mma_bf16"
     else:
       prefix = "__imma"
   elif frag.geom == "m8n8k32":
     prefix = "__imma" # sub-integers
   elif frag.geom == "m8n8k128":
     prefix = "__bmma"
+  elif frag.geom == "m8n8k4":
+    prefix = "__dmma"
+  elif frag.geom == "m16n16k8":
+    if frag.ptx_type == "f32":
+      prefix = "__mma"
+    else:
+      prefix = "__mma_tf32"
   assert prefix
   return prefix
 
@@ -123,10 +152,13 @@ def get_ldst_builtin_name(frag):
 
   if prefix == "__hmma":
     suffix = "" if frag.frag in ["a","b"] else frag.ptx_type
-  elif prefix in ["__imma", "__bmma"]:
-    suffix = "" if frag.frag in ["c"] else frag.ptx_type
+  elif prefix in ["__dmma", "__mma_bf16", "__mma_tf32"]:
+    suffix = "" if frag.frag in ["a","b","c"] else frag.ptx_type
+  else:
+    suffix = "" if frag.frag == "c" else frag.ptx_type
     if suffix == "s32":
       suffix = "i32"
+
   if frag.frag == "d":
     ifrag = "c"
     op = "st"
@@ -143,6 +175,8 @@ def get_mma_builtin_name(op):
 
   if prefix == "__hmma":
     suffix = op.d.ptx_type + op.c.ptx_type
+  elif prefix in ["__mma_bf16", "__mma_tf32"]:
+    suffix = op.d.ptx_type
   else:
     suffix = op.a.ptx_type
 
@@ -151,8 +185,9 @@ def get_mma_builtin_name(op):
                              suffix)
   return name
 
-
 def get_required_sm(frag):
+  if frag.ptx_type in ["f64", "bf16", "tf32"]:
+    return 80
   if frag.ptx_type in ["u4", "s4", "b1"]:
     return 75
   if frag.ptx_type in ["s8", "u8"]:
@@ -163,18 +198,34 @@ def get_required_sm(frag):
     else:                       # s8/u8
       return 72
   if frag.ptx_type in ["f16", "f32"]:
-    return 70
+    if frag.geom == "m16n16k8":
+      return 80
+    else:
+      return 70
   assert(False)
 
 def get_required_ptx(frag):
+  if frag.ptx_type in ["f64", "bf16", "tf32"]:
+    return 70
   if frag.ptx_type in ["f16", "f32"]:
-    return 60 if frag.geom == "m16n16k16" else 61
+    if frag.geom == "m16n16k16":
+      return 60
+    if frag.geom == "m16n16k8":
+      return 70
+    return 61
   return 63
 
+def get_src_dst_prefix(ptx_type):
+  if ptx_type == "f32":
+    return "f"
+  if ptx_type == "f64":
+    return "d"
+  return ""
+
 def gen_wmma_ldst_tests(results):
   load_template = """
   // CHECK${check_suffix}: call {{.*}} @${intrinsic}
-  // expected-error-re@+1 {{'${builtin}' needs target feature sm_${min_sm}{{.*}},ptx${min_ptx}{{.*}}}}
+  // expected-error-re@+1 {{'${builtin}' needs target feature (sm_${min_sm}{{.*}},(ptx${min_ptx}{{.*}}}}
   ${builtin}(${dst}, ${src}, ldm, ${blayout});
 """.rstrip()
   intrinsic_template = "llvm.nvvm.wmma.${geom}.${op}.${frag}.${ilayout}.stride.${itype}"
@@ -184,7 +235,7 @@ def gen_wmma_ldst_tests(results):
     if not is_ldst_variant_supported(frag, layout):
       continue
 
-    is_fp = frag.ptx_type  == "f32"
+    src_dst_prefix = get_src_dst_prefix(frag.ptx_type)
     min_sm = get_required_sm(frag)
     min_ptx = get_required_ptx(frag)
     params = {
@@ -192,8 +243,8 @@ def gen_wmma_ldst_tests(results):
         "builtin" : get_ldst_builtin_name(frag),
         "min_ptx" : min_ptx,
         "min_sm" : min_sm,
-        "dst": "fdst" if is_fp else "dst",
-        "src": "fsrc" if is_fp else "src",
+        "dst": src_dst_prefix + "dst",
+        "src": src_dst_prefix + "src",
         "blayout" : 0 if layout == "row" else 1,
         "intrinsic" : Template(intrinsic_template).substitute({
             "frag" : frag.frag,
@@ -208,12 +259,12 @@ def gen_wmma_ldst_tests(results):
   return results
 
 def mma_signature(op):
-  if op.a.ptx_type in ["s8", "u8", "s4", "u4", "b1"]:
-    # int and sub-int ops are identified by input type.
-    return op.a.ptx_type
-  else:
-    # the rest are FP ops identified by accumulator & result type.
+  if op.a.ptx_type == "f16":
+    # FP16 ops identified by accumulator & result type.
     return "%s.%s" % (op.d.ptx_type, op.c.ptx_type)
+  else:
+    # other ops are identified by input type.
+    return op.a.ptx_type
 
 # Get numeric value for rowcol parameter of the builtin
 # AFAICT it uses the encoding accepted by NVVM intrinsics:
@@ -229,8 +280,8 @@ def get_ilayout(a, b):
 def gen_wmma_mma_tests(results):
   mma_template = """
   // CHECK${check_suffix}: call {{.*}} @${intrinsic}
-  // expected-error-re@+1 {{'${builtin}' needs target feature sm_${min_sm}{{.*}},ptx${min_ptx}{{.*}}}}
-  ${builtin}(${dst}, ${asrc}, ${asrc}, ${csrc}, ${ilayout}${maybe_isatf});
+  // expected-error-re@+1 {{'${builtin}' needs target feature (sm_${min_sm}{{.*}},(ptx${min_ptx}{{.*}}}}
+  ${builtin}(${dst}, ${asrc}, ${asrc}, ${csrc}, ${ilayout}${maybe_satf});
 """.rstrip()
   intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${intrinsic_signature}${satf}"
 
@@ -243,9 +294,9 @@ def gen_wmma_mma_tests(results):
     if not is_mma_variant_supported(op, alayout, blayout, satf):
       continue
 
-    a_is_fp = op.a.ptx_type  == "f32"
-    c_is_fp = op.c.ptx_type  == "f32"
-    d_is_fp = op.d.ptx_type  == "f32"
+    asrc_prefix = get_src_dst_prefix(op.a.ptx_type)
+    csrc_prefix = get_src_dst_prefix(op.c.ptx_type)
+    ddst_prefix = get_src_dst_prefix(op.d.ptx_type)
     min_sm = get_required_sm(op.a)
     min_ptx = get_required_ptx(op.a)
     if op.a.ptx_type == "b1": # .b1 MMA has no satf argument.
@@ -257,11 +308,11 @@ def gen_wmma_mma_tests(results):
         "builtin" : get_mma_builtin_name(op),
         "min_ptx" : min_ptx,
         "min_sm" : min_sm,
-        "dst": "fdst" if d_is_fp else "dst",
-        "asrc": "fsrc" if a_is_fp else "src",
-        "csrc": "fsrc" if c_is_fp else "src",
+        "dst": ddst_prefix + "dst",
+        "asrc": asrc_prefix + "src",
+        "csrc": csrc_prefix + "src",
         "ilayout" : get_ilayout(alayout, blayout),
-        "maybe_isatf" : isatf_arg,
+        "maybe_satf" : isatf_arg,
         "intrinsic" : Template(intrinsic_template).substitute({
             "geom"  : op.a.geom,
             "alayout" : alayout,
@@ -322,7 +373,8 @@ def supported_variants(ptx, sm, results):
 
 // CHECK-LABEL: test_wmma_buitins
 __device__ void test_wmma_buitins(int *src, int *dst,
-                                  float *fsrc, float *fdst, int ldm) {
+                                  float *fsrc, float *fdst,
+                                  double *dsrc, double *ddst, int ldm) {
 """);
 
   for (ptx, sm), tests in sorted(results.items()):

diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 71e31b14f4c3b..3ce9dfb1bb807 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -52,13 +52,27 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
   string gft = Geom#":"#Frag#":"#ptx_elt_type;
   string ft = frag#":"#ptx_elt_type;
   list<LLVMType> regs = !cond(
-    // mma.sync.m8n8k4 uses smaller a/b fragments than wmma fp ops
+    // mma fp ops use smaller fragments than wmma fp ops
     !eq(gft,"m8n8k4:a:f16") : !listsplat(llvm_v2f16_ty, 2),
     !eq(gft,"m8n8k4:b:f16") : !listsplat(llvm_v2f16_ty, 2),
-
-    // fp16 -> fp16/fp32 @  m16n16k16/m8n32k16/m32n8k16
-    // All currently supported geometries use the same fragment format,
-    // so we only need to consider {fragment, type}.
+    !eq(gft,"m16n8k8:a:f16") : !listsplat(llvm_v2f16_ty, 2),
+    !eq(gft,"m16n8k8:b:f16") : [llvm_v2f16_ty],
+    !eq(gft,"m16n8k8:c:f16") : !listsplat(llvm_v2f16_ty, 2),
+    !eq(gft,"m16n8k8:d:f16") : !listsplat(llvm_v2f16_ty, 2),
+    !eq(gft,"m16n8k8:c:f32") : !listsplat(llvm_float_ty, 4),
+    !eq(gft,"m16n8k8:d:f32") : !listsplat(llvm_float_ty, 4),
+    !eq(gft,"m16n8k16:a:f16") : !listsplat(llvm_v2f16_ty, 4),
+    !eq(gft,"m16n8k16:b:f16") : !listsplat(llvm_v2f16_ty, 2),
+    !eq(gft,"m16n8k16:c:f16") : !listsplat(llvm_v2f16_ty, 2),
+    !eq(gft,"m16n8k16:d:f16") : !listsplat(llvm_v2f16_ty, 2),
+    !eq(gft,"m16n8k16:c:f32") : !listsplat(llvm_float_ty, 4),
+    !eq(gft,"m16n8k16:d:f32") : !listsplat(llvm_float_ty, 4),
+    !eq(gft,"m16n8k4:c:f32") : !listsplat(llvm_float_ty, 4),
+    !eq(gft,"m16n8k4:d:f32") : !listsplat(llvm_float_ty, 4),
+
+    // wmma fp16 -> fp16/fp32 @  m16n16k16/m8n32k16/m32n8k16
+    // All other supported geometries use the same fragment format for f32 and
+    // f16, so we only need to consider {fragment, type}.
     !eq(ft,"a:f16") : !listsplat(llvm_v2f16_ty, 8),
     !eq(ft,"b:f16") : !listsplat(llvm_v2f16_ty, 8),
     !eq(ft,"c:f16") : !listsplat(llvm_v2f16_ty, 4),
@@ -66,7 +80,36 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
     !eq(ft,"c:f32") : !listsplat(llvm_float_ty, 8),
     !eq(ft,"d:f32") : !listsplat(llvm_float_ty, 8),
 
-    // u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
+    // wmma tf32 -> s32 @ m16n16k8
+    !eq(gft,"m16n16k8:a:tf32") : !listsplat(llvm_i32_ty, 4),
+    !eq(gft,"m16n16k8:b:tf32") : !listsplat(llvm_i32_ty, 4),
+
+    // mma tf32 -> s32 @ m16n16k8/m16n8k8
+    !eq(gft,"m16n8k4:a:tf32") : !listsplat(llvm_i32_ty, 2),
+    !eq(gft,"m16n8k4:b:tf32") : [llvm_i32_ty],
+    !eq(gft,"m16n8k8:a:tf32") : !listsplat(llvm_i32_ty, 4),
+    !eq(gft,"m16n8k8:b:tf32") : !listsplat(llvm_i32_ty, 2),
+
+    !eq(gft,"m8n8k4:a:f64") : [llvm_double_ty],
+    !eq(gft,"m8n8k4:b:f64") : [llvm_double_ty],
+    !eq(gft,"m8n8k4:c:f64") : !listsplat(llvm_double_ty, 2),
+    !eq(gft,"m8n8k4:d:f64") : !listsplat(llvm_double_ty, 2),
+
+    // wmma bf16 -> s32 @ m16n16k16/m8n32k16/m32n8k16
+    !eq(gft,"m16n16k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
+    !eq(gft,"m16n16k16:b:bf16") : !listsplat(llvm_i32_ty, 4),
+    !eq(gft,"m8n32k16:a:bf16") : !listsplat(llvm_i32_ty, 2),
+    !eq(gft,"m8n32k16:b:bf16") : !listsplat(llvm_i32_ty, 8),
+    !eq(gft,"m32n8k16:a:bf16") : !listsplat(llvm_i32_ty, 8),
+    !eq(gft,"m32n8k16:b:bf16") : !listsplat(llvm_i32_ty, 2),
+
+    // mma bf16 -> s32 @ m16n8k16/m16n8k8
+    !eq(gft,"m16n8k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
+    !eq(gft,"m16n8k16:b:bf16") : !listsplat(llvm_i32_ty, 2),
+    !eq(gft,"m16n8k8:a:bf16") : !listsplat(llvm_i32_ty, 2),
+    !eq(gft,"m16n8k8:b:bf16") : [llvm_i32_ty],
+
+    // wmma u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
     !eq(gft,"m16n16k16:a:u8") : !listsplat(llvm_i32_ty, 2),
     !eq(gft,"m16n16k16:a:s8") : !listsplat(llvm_i32_ty, 2),
     !eq(gft,"m16n16k16:b:u8") : !listsplat(llvm_i32_ty, 2),
@@ -88,17 +131,65 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
     !eq(gft,"m32n8k16:c:s32") : !listsplat(llvm_i32_ty, 8),
     !eq(gft,"m32n8k16:d:s32") : !listsplat(llvm_i32_ty, 8),
 
-    // u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1)
-    !eq(gft,"m8n8k128:a:b1") : [llvm_i32_ty],
+    // mma u8/s8 -> s32 @ m8n8k16/m16n8k16/m16n8k32
+    !eq(gft,"m8n8k16:a:u8") : [llvm_i32_ty],
+    !eq(gft,"m8n8k16:a:s8") : [llvm_i32_ty],
+    !eq(gft,"m8n8k16:b:u8") : [llvm_i32_ty],
+    !eq(gft,"m8n8k16:b:s8") : [llvm_i32_ty],
+    !eq(gft,"m8n8k16:c:s32") : !listsplat(llvm_i32_ty, 2),
+    !eq(gft,"m8n8k16:d:s32") : !listsplat(llvm_i32_ty, 2),
+
+    !eq(gft,"m16n8k16:a:u8") : !listsplat(llvm_i32_ty, 2),
+    !eq(gft,"m16n8k16:a:s8") : !listsplat(llvm_i32_ty, 2),
+    !eq(gft,"m16n8k16:b:u8") : [llvm_i32_ty],
+    !eq(gft,"m16n8k16:b:s8") : [llvm_i32_ty],
+    !eq(gft,"m16n8k16:c:s32") : !listsplat(llvm_i32_ty, 4),
+    !eq(gft,"m16n8k16:d:s32") : !listsplat(llvm_i32_ty, 4),
+
+    !eq(gft,"m16n8k32:a:u8") : !listsplat(llvm_i32_ty, 4),
+    !eq(gft,"m16n8k32:a:s8") : !listsplat(llvm_i32_ty, 4),
+    !eq(gft,"m16n8k32:b:u8") : !listsplat(llvm_i32_ty, 2),
+    !eq(gft,"m16n8k32:b:s8") : !listsplat(llvm_i32_ty, 2),
+    !eq(gft,"m16n8k32:c:s32") : !listsplat(llvm_i32_ty, 4),
+    !eq(gft,"m16n8k32:d:s32") : !listsplat(llvm_i32_ty, 4),
+
+    // wmma/mma u4/s4 -> s32 @ m8n8k32 (u4/s4)
     !eq(gft,"m8n8k32:a:u4") : [llvm_i32_ty],
     !eq(gft,"m8n8k32:a:s4") : [llvm_i32_ty],
-    !eq(gft,"m8n8k128:b:b1") : [llvm_i32_ty],
     !eq(gft,"m8n8k32:b:u4") : [llvm_i32_ty],
     !eq(gft,"m8n8k32:b:s4") : [llvm_i32_ty],
-    !eq(gft,"m8n8k128:c:s32") : !listsplat(llvm_i32_ty, 2),
-    !eq(gft,"m8n8k128:d:s32") : !listsplat(llvm_i32_ty, 2),
     !eq(gft,"m8n8k32:c:s32") : !listsplat(llvm_i32_ty, 2),
     !eq(gft,"m8n8k32:d:s32") : !listsplat(llvm_i32_ty, 2),
+
+    !eq(gft,"m16n8k32:a:u4") : !listsplat(llvm_i32_ty, 2),
+    !eq(gft,"m16n8k32:a:s4") : !listsplat(llvm_i32_ty, 2),
+    !eq(gft,"m16n8k32:b:u4") : [llvm_i32_ty],
+    !eq(gft,"m16n8k32:b:s4") : [llvm_i32_ty],
+    !eq(gft,"m16n8k32:c:s32") : !listsplat(llvm_i32_ty, 4),
+    !eq(gft,"m16n8k32:d:s32") : !listsplat(llvm_i32_ty, 4),
+
+    !eq(gft,"m16n8k64:a:u4") : !listsplat(llvm_i32_ty, 4),
+    !eq(gft,"m16n8k64:a:s4") : !listsplat(llvm_i32_ty, 4),
+    !eq(gft,"m16n8k64:b:u4") : !listsplat(llvm_i32_ty, 2),
+    !eq(gft,"m16n8k64:b:s4") : !listsplat(llvm_i32_ty, 2),
+    !eq(gft,"m16n8k64:c:s32") : !listsplat(llvm_i32_ty, 4),
+    !eq(gft,"m16n8k64:d:s32") : !listsplat(llvm_i32_ty, 4),
+
+    // wmma/mma b1 -> s32 @ m8n8k128(b1)
+    !eq(gft,"m8n8k128:a:b1") : [llvm_i32_ty],
+    !eq(gft,"m8n8k128:b:b1") : [llvm_i32_ty],
+    !eq(gft,"m8n8k128:c:s32") : !listsplat(llvm_i32_ty, 2),
+    !eq(gft,"m8n8k128:d:s32") : !listsplat(llvm_i32_ty, 2),
+
+    !eq(gft,"m16n8k128:a:b1") : !listsplat(llvm_i32_ty, 2),
+    !eq(gft,"m16n8k128:b:b1") : [llvm_i32_ty],
+    !eq(gft,"m16n8k128:c:s32") : !listsplat(llvm_i32_ty, 4),
+    !eq(gft,"m16n8k128:d:s32") : !listsplat(llvm_i32_ty, 4),
+
+    !eq(gft,"m16n8k256:a:b1") : !listsplat(llvm_i32_ty, 4),
+    !eq(gft,"m16n8k256:b:b1") : !listsplat(llvm_i32_ty, 2),
+    !eq(gft,"m16n8k256:c:s32") : !listsplat(llvm_i32_ty, 4),
+    !eq(gft,"m16n8k256:d:s32") : !listsplat(llvm_i32_ty, 4),
   );
 }
 
@@ -125,35 +216,40 @@ class WMMA_NAME_LDST<string Op, WMMA_REGS Frag, string Layout, int WithStride> {
 
 class MMA_SIGNATURE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
   list<WMMA_REGS> id_frags = !cond(
-     // int and sub-int ops are identified by input type.
-     !eq(A.ptx_elt_type, "s8") : [A],
-     !eq(A.ptx_elt_type, "u8") : [A],
-     !eq(A.ptx_elt_type, "s4") : [A],
-     !eq(A.ptx_elt_type, "u4") : [A],
-     !eq(A.ptx_elt_type, "b1") : [A],
-     // the rest are FP ops identified by accumulator & result type.
-     true: [D, C]
+     // FP16 ops are identified by accumulator & result type.
+     !eq(A.ptx_elt_type, "f16") : [D, C],
+     // other ops are identified by input types.
+     !ne(A.ptx_elt_type, B.ptx_elt_type): [A, B],
+     true: [A]
      );
    string ret = !foldl("", id_frags, a, b, !strconcat(a, ".", b.ptx_elt_type));
 }
 
-class WMMA_NAME_MMA<string ALayout, string BLayout, int Satfinite,
-                    WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+class WMMA_NAME<string ALayout, string BLayout, int Satfinite, string Rnd,
+                WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
   string signature = MMA_SIGNATURE<A, B, C, D>.ret;
-  string llvm = !if(
-    !eq(A.geom, "m8n8k4"),
-        "llvm.nvvm.mma.m8n8k4"
-           # "." # ALayout
-           # "." # BLayout
-           # signature,
-        "llvm.nvvm.wmma."
-           # A.geom
-           # ".mma"
-           # "." # ALayout
-           # "." # BLayout
-           # signature
-           # !if(Satfinite, ".satfinite", ""));
+  string llvm = "llvm.nvvm.wmma."
+                # A.geom
+                # ".mma"
+                # "." # ALayout
+                # "." # BLayout
+                # !if(!ne(Rnd, ""), !strconcat(".", Rnd), "")
+                # signature
+                # !if(Satfinite, ".satfinite", "");
+
+  string record = !subst(".", "_",
+                  !subst("llvm.", "int_", llvm));
+}
 
+class MMA_NAME<string ALayout, string BLayout, int Satfinite,
+               WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+  string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+  string llvm = "llvm.nvvm.mma."
+                # A.geom
+                # "." # ALayout
+                # "." # BLayout
+                # !if(Satfinite, ".satfinite", "")
+                # signature;
   string record = !subst(".", "_",
                   !subst("llvm.", "int_", llvm));
 }
@@ -188,14 +284,18 @@ class MMA_LDST_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
    list<string> ops = !foreach(x, ret, x.gft);
 }
 
-
-
 // Creates list of valid combinations of fragments. This is the master list that
 // drives generation of corresponding intrinsics and instructions.
 class NVVM_MMA_OPS<int _ = 0> {
-  list<list<WMMA_REGS>> fp_mma_ops = MMA_OPS<
+  list<list<WMMA_REGS>> tf32_wmma_ops = MMA_OPS<
+            ["m16n16k8"],
+            ["tf32"], [], ["f32"], []>.ret;
+  list<list<WMMA_REGS>> bf16_wmma_ops = MMA_OPS<
+            ["m16n16k16", "m32n8k16", "m8n32k16"],
+            ["bf16"], [], ["f32"], []>.ret;
+  list<list<WMMA_REGS>> f64_wmma_ops = MMA_OPS<
             ["m8n8k4"],
-            ["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
+            ["f64"], [], ["f64"], []>.ret;
   list<list<WMMA_REGS>> fp_wmma_ops = MMA_OPS<
             ["m16n16k16", "m32n8k16", "m8n32k16"],
             ["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
@@ -208,16 +308,50 @@ class NVVM_MMA_OPS<int _ = 0> {
   list<list<WMMA_REGS>> bit_wmma_ops = MMA_OPS<
             ["m8n8k128"],
             ["b1"], [], ["s32"], []>.ret;
+  list<list<WMMA_REGS>> all_wmma_ops = !listconcat(
+            tf32_wmma_ops, bf16_wmma_ops, f64_wmma_ops,
+            fp_wmma_ops, int_wmma_ops, subint_wmma_ops, bit_wmma_ops);
+
+  list<list<WMMA_REGS>> tf32_mma_ops = MMA_OPS<
+            ["m16n8k4", "m16n8k8"],
+            ["tf32"], [], ["f32"], []>.ret;
+  list<list<WMMA_REGS>> bf16_mma_ops = MMA_OPS<
+            ["m16n8k16", "m16n8k8"],
+            ["bf16"], [], ["f32"], []>.ret;
+  list<list<WMMA_REGS>> f64_mma_ops = MMA_OPS<
+            ["m8n8k4"],
+            ["f64"], [], ["f64"], []>.ret;
+  list<list<WMMA_REGS>> fp_mma_ops = MMA_OPS<
+            ["m8n8k4", "m16n8k8", "m16n8k16"],
+            ["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
+  list<list<WMMA_REGS>> int_mma_ops = MMA_OPS<
+            ["m8n8k16", "m16n8k16", "m16n8k32"],
+            ["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret;
+  list<list<WMMA_REGS>> subint_mma_ops = MMA_OPS<
+            ["m8n8k32", "m16n8k32", "m16n8k64"],
+            ["s4", "u4"], ["s4", "u4"], ["s32"], []>.ret;
+  list<list<WMMA_REGS>> bit_mma_ops = MMA_OPS<
+            ["m8n8k128", "m16n8k128", "m16n8k256"],
+            ["b1"], [], ["s32"], []>.ret;
   list<list<WMMA_REGS>> all_mma_ops = !listconcat(
-            fp_mma_ops, fp_wmma_ops, int_wmma_ops,
-            subint_wmma_ops, bit_wmma_ops);
+            tf32_mma_ops, bf16_mma_ops, f64_mma_ops,
+            fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops);
 
   list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
             ["m16n16k16", "m32n8k16", "m8n32k16"],
-            ["a", "b"], ["f16", "u8", "s8"]>.ret;
+            ["a", "b"], ["f16", "u8", "s8", "bf16"]>.ret;
   list<WMMA_REGS> ldst_cd_ops = MMA_LDST_OPS<
             ["m16n16k16", "m32n8k16", "m8n32k16"],
             ["c", "d"], ["f16", "f32", "s32"]>.ret;
+  list<WMMA_REGS> ldst_tf32_ab_ops = MMA_LDST_OPS<
+            ["m16n16k8"],
+            ["a", "b"], ["tf32"]>.ret;
+  list<WMMA_REGS> ldst_tf32_cd_ops = MMA_LDST_OPS<
+            ["m16n16k8"],
+            ["c", "d"], ["f32"]>.ret;
+  list<WMMA_REGS> ldst_f64_abcd_ops = MMA_LDST_OPS<
+            ["m8n8k4"],
+            ["a", "b", "c", "d"], ["f64"]>.ret;
   list<WMMA_REGS> ldst_subint_ab_ops = MMA_LDST_OPS<
             ["m8n8k32"], ["a", "b"], ["s4","u4"]>.ret;
   list<WMMA_REGS> ldst_bit_ab_ops = MMA_LDST_OPS<
@@ -225,6 +359,9 @@ class NVVM_MMA_OPS<int _ = 0> {
   list<WMMA_REGS> ldst_subint_cd_ops = MMA_LDST_OPS<
             ["m8n8k32", "m8n8k128"],  ["c", "d"], ["s32"]>.ret;
   list<WMMA_REGS> all_ldst_ops = !listconcat(ldst_ab_ops, ldst_cd_ops,
+                                             ldst_tf32_ab_ops,
+                                             ldst_tf32_cd_ops,
+                                             ldst_f64_abcd_ops,
                                              ldst_subint_ab_ops,
                                              ldst_bit_ab_ops,
                                              ldst_subint_cd_ops);
@@ -235,69 +372,110 @@ class NVVM_MMA_OPS<int _ = 0> {
 
 def NVVM_MMA_OPS : NVVM_MMA_OPS;
 
-// Returns true if this combination of layout/satf is supported; false otherwise.
-// MMA ops must provide all parameters. Loads and stores -- only frags and layout_a.
-// The class is used to prevent generation of records for the unsupported variants.
+
+// Returns true if this combination of fragment and layout for WMMA load/store
+// ops is supported; false otherwise.
+// E.g.
+// if NVVM_WMMA_LDST_SUPPORTED<...>.ret then
+//   def : FOO<>; // The record will only be defined for supported ops.
+//
+class NVVM_WMMA_LDST_SUPPORTED<WMMA_REGS frag, string layout> {
+  string f = frag.frag;
+  string t = frag.ptx_elt_type;
+
+  bit ret = !cond(
+    // Sub-int load and store requires A fragment to be of row layout and B
+    // fragments to be of column layout.
+    !and(!or(!eq(t, "b1"),
+             !eq(t, "u4"),
+             !eq(t, "s4")),
+         !or(!and(!eq(f, "a"),
+                  !ne(layout, "row")),
+             !and(!eq(f, "b"),
+                  !ne(layout, "col")))) : false,
+    true: true
+  );
+}
+
+// Returns true if this combination of layout/satf/rnd for WMMA ops is
+// supported; false otherwise.
+// E.g.
+// if NVVM_WMMA_SUPPORTED<...>.ret then
+//   def : FOO<>; // The record will only be defined for supported ops.
+//
+class NVVM_WMMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b, int satf, string rnd> {
+  // WMMA ops check both layouts.
+  string layout = layout_a # ":" # layout_b;
+  string t = frags[0].ptx_elt_type;
+
+  bit ret = !cond(
+    // Only f64 WMMA variants support the rnd option;
+    // any non-f64 type that specifies a rnd value is invalid.
+    !and(!ne(t, "f64"), !ne(rnd, "")) : false,
+
+    // satf is only valid for select types
+    !and(!eq(satf, 1),
+         !ne(t, "s8"),
+         !ne(t, "u8"),
+         !ne(t, "s4"),
+         !ne(t, "u4"),
+         !ne(t, "f16")): false,
+
+    // Sub-int wmma requires row/column layout
+    !and(!or(!eq(t, "s4"),
+             !eq(t, "u4"),
+             !eq(t, "b1")),
+         !ne(layout, "row:col")) : false,
+    true: true
+  );
+}
+
+// Returns true if this combination of layout/satf for MMA ops is supported;
+// false otherwise.
 // E.g.
 // if NVVM_MMA_SUPPORTED<...>.ret then
 //   def : FOO<>; // The record will only be defined for supported ops.
 //
-class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b="-", int satf=-1> {
+class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b, int satf> {
   // MMA ops check both layouts.
-  string mma = frags[0].ptx_elt_type
-               # ":" # layout_a
-               # ":" # layout_b;
-  // Load ops only need type/fragment/layout.
-  string ld = frags[0].ptx_elt_type
-               # ":" # frags[0].frag
-               # ":" # layout_a
-               ;
-  string ldf = frags[0].ptx_elt_type
-               # ":" # frags[0].frag
-               ;
-  string t = frags[0].ptx_elt_type;
+  string layout = layout_a # ":" # layout_b;
+  string a_type = frags[0].ptx_elt_type;
+  string b_type = frags[1].ptx_elt_type;
+  string c_type = frags[2].ptx_elt_type;
+  string d_type = frags[3].ptx_elt_type;
+  string geom = frags[0].geom;
 
   // gcd is a shortcut used to identify instructions that depend on
-  // geom+frag_c+frag_d.  Not all instances of this class have all fragments
-  // specified. If there are not enough fragments, the tail evaluates to '?'.
-  string gcd = frags[0].geom
-               # ":"
-               # !if(!eq(!size(frags), 4),
-                     frags[2].ptx_elt_type # frags[3].ptx_elt_type,
-                     "?");
+  // geom+frag_c+frag_d.
+  string gcd = geom # ":" # c_type # d_type;
   bit ret = !cond(
-    // Sub-int MMA only supports fixed A/B layout.
-    // b1 does not support .satf.
-    !eq(mma#":"#satf, "b1:row:col:0") : true,
-    // mma.m8n8k4 has no .satf modifier.
-    !and(!eq(frags[0].geom, "m8n8k4"),
-         !ne(satf, 0)): false,
-
-    // mma.m8n8k4 has no C=f32 D=f16 variant.
+
+    // Limit satf to valid types
+    !and(!eq(satf, 1),
+         !ne(a_type, "s8"),
+         !ne(a_type, "u8"),
+         !ne(a_type, "s4"),
+         !ne(a_type, "u4")): false,
+
+    // m8n8k4 has no C=f32 D=f16 variant.
     !eq(gcd, "m8n8k4:f32f16"): false,
-    !eq(mma, "s4:row:col") : true,
-    !eq(mma, "u4:row:col") : true,
-    !eq(mma, "s4:row:col") : true,
-    !eq(mma, "u4:row:col") : true,
-    // Sub-int load/stores have fixed layout for A and B.
-    !and(!eq(layout_b, "-"), // It's a Load or Store op
-         !or(!eq(ld,  "b1:a:row"),
-             !eq(ld,  "b1:b:col"),
-             !eq(ldf, "b1:c"),
-             !eq(ldf, "b1:d"),
-             !eq(ld, "s4:a:row"),
-             !eq(ld, "s4:b:col"),
-             !eq(ldf, "s4:c"),
-             !eq(ldf, "s4:d"),
-             !eq(ld, "u4:a:row"),
-             !eq(ld, "u4:b:col"),
-             !eq(ldf, "u4:c"),
-             !eq(ldf, "u4:d"))) : true,
-    // All other sub-int ops are not supported.
-    !eq(t, "b1") : false,
-    !eq(t, "s4") : false,
-    !eq(t, "u4") : false,
-    // All other (non sub-int) are OK.
+
+    // Only m8n8k4 with f16 does not require row:col layout.
+    !and(!ne(layout, "row:col"),
+         !or(!ne(geom, "m8n8k4"),
+             !ne(a_type, "f16"))) : false,
+
+    // m16n8k8 requires A and B to be the same type and C and D to be the same
+    // type.
+    !and(!eq(geom, "m16n8k8"),
+         !or(!ne(a_type, b_type),
+             !ne(c_type, d_type))): false,
+
+    // m16n8k8 requires C and D to be the same type.
+    !and(!eq(geom, "m16n8k8"),
+         !ne(c_type, d_type)): false,
+
+    // All others are OK.
     true: true
   );
 }
@@ -4271,36 +4449,59 @@ class NVVM_WMMA_ST<WMMA_REGS Frag, string Layout, int WithStride>
 foreach layout = ["row", "col"] in {
   foreach stride = [0, 1] in {
     foreach frag = NVVM_MMA_OPS.all_ld_ops in
-      if NVVM_MMA_SUPPORTED<[frag], layout>.ret then
+      if NVVM_WMMA_LDST_SUPPORTED<frag, layout>.ret then
         def WMMA_NAME_LDST<"load", frag, layout, stride>.record
              : NVVM_WMMA_LD<frag, layout, stride>;
     foreach frag = NVVM_MMA_OPS.all_st_ops in
-      if NVVM_MMA_SUPPORTED<[frag], layout>.ret then
+      if NVVM_WMMA_LDST_SUPPORTED<frag, layout>.ret then
         def WMMA_NAME_LDST<"store", frag, layout, stride>.record
              : NVVM_WMMA_ST<frag, layout, stride>;
   }
 }
 
 // WMMA.MMA
-class NVVM_WMMA_MMA<string ALayout, string BLayout, int Satfinite,
+class NVVM_WMMA_MMA<string ALayout, string BLayout, int Satfinite, string rnd,
                     WMMA_REGS A, WMMA_REGS B,
                     WMMA_REGS C, WMMA_REGS D>
   : Intrinsic<D.regs,
               !listconcat(A.regs, B.regs, C.regs),
               [IntrNoMem],
-              WMMA_NAME_MMA<ALayout, BLayout, Satfinite, A, B, C, D>.llvm>;
+              WMMA_NAME<ALayout, BLayout, Satfinite, rnd, A, B, C, D>.llvm>;
+
+foreach layout_a = ["row", "col"] in {
+  foreach layout_b = ["row", "col"] in {
+    foreach satf = [0, 1] in {
+      foreach rnd = ["", "rn", "rz", "rm", "rp"] in {
+        foreach op = NVVM_MMA_OPS.all_wmma_ops in {
+          if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, satf, rnd>.ret then {
+            def WMMA_NAME<layout_a, layout_b, satf, rnd,
+                              op[0], op[1], op[2], op[3]>.record
+              : NVVM_WMMA_MMA<layout_a, layout_b, satf, rnd,
+                              op[0], op[1], op[2], op[3]>;
+          }
+        } // op
+      } // rnd
+    } // satf
+  } // layout_b
+} // layout_a
+
+// MMA
+class NVVM_MMA<string ALayout, string BLayout, int Satfinite,
+               WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
+  : Intrinsic<D.regs,
+              !listconcat(A.regs, B.regs, C.regs),
+              [IntrNoMem],
+              MMA_NAME<ALayout, BLayout, Satfinite, A, B, C, D>.llvm>;
 
 foreach layout_a = ["row", "col"] in {
   foreach layout_b = ["row", "col"] in {
     foreach satf = [0, 1] in {
       foreach op = NVVM_MMA_OPS.all_mma_ops in {
         if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
-          def WMMA_NAME_MMA<layout_a, layout_b, satf,
-                            op[0], op[1], op[2], op[3]>.record
-            : NVVM_WMMA_MMA<layout_a, layout_b, satf,
-                            op[0], op[1], op[2], op[3]>;
+          def MMA_NAME<layout_a, layout_b, satf, op[0], op[1], op[2], op[3]>.record
+            : NVVM_MMA<layout_a, layout_b, satf, op[0], op[1], op[2], op[3]>;
         }
-      }
+      } // op
     } // satf
   } // layout_b
 } // layout_a
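
For readers who do not want to mentally expand the nested foreach loops, the
following is a rough, illustrative Python sketch (not part of the patch; the
helper name and arguments are made up) of the filtering NVVM_WMMA_SUPPORTED
performs before a WMMA variant gets an intrinsic record:

from itertools import product

# Illustrative only: mirrors the !cond checks in NVVM_WMMA_SUPPORTED above.
# 'a_type' plays the role of frags[0].ptx_elt_type.
def wmma_variant_ok(a_type, layout_a, layout_b, satf, rnd):
    if rnd and a_type != "f64":                       # rnd is f64-only
        return False
    if satf and a_type not in ("s8", "u8", "s4", "u4", "f16"):
        return False                                  # satf limited to these types
    if a_type in ("s4", "u4", "b1") and (layout_a, layout_b) != ("row", "col"):
        return False                                  # sub-int wmma needs row:col
    return True

# f64 keeps every rnd mode (but never satf); f16 keeps only the default rounding.
f64_variants = [v for v in product(["row", "col"], ["row", "col"], [0, 1],
                                   ["", "rn", "rz", "rm", "rp"])
                if wmma_variant_ok("f64", *v)]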

diff  --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d431f20d066f6..d4842c953ce7a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3490,6 +3490,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row_stride:
   case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row_stride:
   case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row_stride:
   case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col:
   case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col_stride:
   case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col_stride:
@@ -3497,7 +3501,11 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row:
   case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row_stride:
   case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row_stride:
-  case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row: {
+  case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row_stride: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v2i32;
     Info.ptrVal = I.getArgOperand(0);
@@ -3515,6 +3523,14 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row_stride:
   case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row_stride:
   case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col_stride:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row_stride:
+  case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col:
+  case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col_stride:
+  case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row:
+  case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row_stride:
 
   case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col:
   case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col_stride:
@@ -3523,7 +3539,15 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row:
   case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row_stride:
   case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row_stride:
-  case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row: {
+  case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col_stride:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row_stride:
+  case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col:
+  case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col_stride:
+  case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
+  case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v4i32;
     Info.ptrVal = I.getArgOperand(0);
@@ -3603,7 +3627,11 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col:
   case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row:
   case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride:
-  case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride: {
+  case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride:
+  case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col:
+  case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row:
+  case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col_stride:
+  case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row_stride: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v8f32;
     Info.ptrVal = I.getArgOperand(0);
@@ -3613,6 +3641,16 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
     return true;
   }
 
+  case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row_stride:
+
+  case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row_stride:
+
   case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col:
   case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col_stride:
   case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row:
@@ -3651,6 +3689,37 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
     return true;
   }
 
+  case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col:
+  case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col_stride:
+  case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row:
+  case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row_stride:
+
+  case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col:
+  case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col_stride:
+  case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row:
+  case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row_stride: {
+    Info.opc = ISD::INTRINSIC_W_CHAIN;
+    Info.memVT = MVT::f64;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.flags = MachineMemOperand::MOLoad;
+    Info.align = Align(8);
+    return true;
+  }
+
+  case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col:
+  case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col_stride:
+  case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row:
+  case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row_stride: {
+    Info.opc = ISD::INTRINSIC_W_CHAIN;
+    Info.memVT = MVT::v2f64;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.flags = MachineMemOperand::MOLoad;
+    Info.align = Align(16);
+    return true;
+  }
+
   case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
   case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
   case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
@@ -3683,7 +3752,11 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col:
   case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row:
   case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride:
-  case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride: {
+  case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride:
+  case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col:
+  case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row:
+  case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col_stride:
+  case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row_stride: {
     Info.opc = ISD::INTRINSIC_VOID;
     Info.memVT = MVT::v8f32;
     Info.ptrVal = I.getArgOperand(0);
@@ -3731,6 +3804,19 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
     return true;
   }
 
+  case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col:
+  case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col_stride:
+  case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row:
+  case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row_stride: {
+    Info.opc = ISD::INTRINSIC_VOID;
+    Info.memVT = MVT::v2f64;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.flags = MachineMemOperand::MOStore;
+    Info.align = Align(16);
+    return true;
+  }
+
   case Intrinsic::nvvm_atomic_load_inc_32:
   case Intrinsic::nvvm_atomic_load_dec_32:
 

diff  --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 5622d5a6fdac5..ab93bf16d4919 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -144,6 +144,7 @@ def hasPTX60 : Predicate<"Subtarget->getPTXVersion() >= 60">;
 def hasPTX61 : Predicate<"Subtarget->getPTXVersion() >= 61">;
 def hasPTX63 : Predicate<"Subtarget->getPTXVersion() >= 63">;
 def hasPTX64 : Predicate<"Subtarget->getPTXVersion() >= 64">;
+def hasPTX65 : Predicate<"Subtarget->getPTXVersion() >= 65">;
 def hasPTX70 : Predicate<"Subtarget->getPTXVersion() >= 70">;
 
 def hasSM30 : Predicate<"Subtarget->getSmVersion() >= 30">;

diff  --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 1aaa9f0dd127d..798538410b104 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1943,21 +1943,21 @@ multiclass VLDU_G_ELE_V2<string TyStr, NVPTXRegClass regclass> {
                      !strconcat("ldu.global.", TyStr), []>;
 }
 
-multiclass VLDU_G_ELE_V4<string TyStr, NVPTXRegClass regclass> { 
+multiclass VLDU_G_ELE_V4<string TyStr, NVPTXRegClass regclass> {
  def _areg32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
-                            regclass:$dst4), (ins Int32Regs:$src), 
+                            regclass:$dst4), (ins Int32Regs:$src),
                !strconcat("ldu.global.", TyStr), []>;
  def _areg64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
-                            regclass:$dst4), (ins Int64Regs:$src), 
+                            regclass:$dst4), (ins Int64Regs:$src),
                !strconcat("ldu.global.", TyStr), []>;
  def _ari32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
-                            regclass:$dst4), (ins MEMri:$src), 
+                            regclass:$dst4), (ins MEMri:$src),
                !strconcat("ldu.global.", TyStr), []>;
  def _ari64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
-                            regclass:$dst4), (ins MEMri64:$src), 
+                            regclass:$dst4), (ins MEMri64:$src),
                !strconcat("ldu.global.", TyStr), []>;
  def _avar: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
-                            regclass:$dst4), (ins imemAny:$src), 
+                            regclass:$dst4), (ins imemAny:$src),
                !strconcat("ldu.global.", TyStr), []>;
 }
 
@@ -1997,7 +1997,7 @@ defm INT_PTX_LDU_G_v4f32_ELE
 
 
 //-----------------------------------
-// Support for ldg on sm_35 or later 
+// Support for ldg on sm_35 or later
 //-----------------------------------
 
 // Don't annotate ld.global.nc as mayLoad, because these loads go through the
@@ -2045,7 +2045,7 @@ defm INT_PTX_LDG_GLOBAL_p64
 
 // vector
 
-// Elementized vector ldg 
+// Elementized vector ldg
 multiclass VLDG_G_ELE_V2<string TyStr, NVPTXRegClass regclass> {
  def _areg32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2),
                      (ins Int32Regs:$src),
@@ -2064,21 +2064,21 @@ multiclass VLDG_G_ELE_V2<string TyStr, NVPTXRegClass regclass> {
                      !strconcat("ld.global.nc.", TyStr), []>;
 }
 
-multiclass VLDG_G_ELE_V4<string TyStr, NVPTXRegClass regclass> { 
+multiclass VLDG_G_ELE_V4<string TyStr, NVPTXRegClass regclass> {
   def _areg32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
-                              regclass:$dst4), (ins Int32Regs:$src), 
+                              regclass:$dst4), (ins Int32Regs:$src),
                !strconcat("ld.global.nc.", TyStr), []>;
   def _areg64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
-                               regclass:$dst4), (ins Int64Regs:$src), 
+                               regclass:$dst4), (ins Int64Regs:$src),
                !strconcat("ld.global.nc.", TyStr), []>;
   def _ari32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
-                              regclass:$dst4), (ins MEMri:$src), 
+                              regclass:$dst4), (ins MEMri:$src),
                !strconcat("ld.global.nc.", TyStr), []>;
   def _ari64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
-                              regclass:$dst4), (ins MEMri64:$src), 
+                              regclass:$dst4), (ins MEMri64:$src),
                !strconcat("ld.global.nc.", TyStr), []>;
   def _avar: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
-                             regclass:$dst4), (ins imemAny:$src), 
+                             regclass:$dst4), (ins imemAny:$src),
                !strconcat("ld.global.nc.", TyStr), []>;
 }
 
@@ -7568,12 +7568,15 @@ def INT_PTX_SREG_WARPSIZE :
 // In addition to target-independent fields provided by WMMA_REGS, it adds
 // the fields commonly used to implement specific PTX instruction -- register
 // types and names, constraints, parts of assembly, etc.
-class WMMA_REGINFO<WMMA_REGS r>
+class WMMA_REGINFO<WMMA_REGS r, string op>
       : WMMA_REGS<r.geom, r.frag, r.ptx_elt_type> {
   // NVPTX register types used to carry fragment data.
   NVPTXRegClass regclass = !cond(
     !eq(ptx_elt_type, "f16") : Float16x2Regs,
     !eq(ptx_elt_type, "f32") : Float32Regs,
+    !eq(ptx_elt_type, "f64") : Float64Regs,
+    !eq(ptx_elt_type, "bf16") : Int32Regs,
+    !eq(ptx_elt_type, "tf32") : Int32Regs,
     !eq(ptx_elt_type, "s32") : Int32Regs,
     !eq(ptx_elt_type, "s8") : Int32Regs,
     !eq(ptx_elt_type, "u8") : Int32Regs,
@@ -7602,6 +7605,9 @@ class WMMA_REGINFO<WMMA_REGS r>
          !or(!eq(ptx_elt_type, "f16"),
              !eq(ptx_elt_type, "f32"))) : [hasSM70, hasPTX60],
 
+    !and(!eq(geom,"m8n8k4"),
+         !eq(ptx_elt_type, "f64")) : [hasSM80, hasPTX70],
+
     // fp16 -> fp16/fp32 @ m8n32k16/m32n8k16
     !and(!or(!eq(geom, "m8n32k16"),
              !eq(geom, "m32n8k16")),
@@ -7616,11 +7622,46 @@ class WMMA_REGINFO<WMMA_REGS r>
              !eq(ptx_elt_type, "s8"),
              !eq(ptx_elt_type, "s32"))) : [hasSM72, hasPTX63],
 
-    // u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1)
-    !or(!eq(geom,"m8n8k128"),
-        !eq(geom,"m8n8k32")) : [hasSM75, hasPTX63],
+    !and(!or(!eq(geom,"m16n16k16"),
+             !eq(geom,"m8n32k16"),
+             !eq(geom,"m32n8k16")),
+         !eq(ptx_elt_type, "bf16")) : [hasSM80, hasPTX70],
+
+    !and(!eq(geom,"m16n16k8"),
+         !eq(ptx_elt_type, "tf32")) : [hasSM80, hasPTX70],
+
+    !and(!eq(geom,"m16n16k8"),
+         !eq(ptx_elt_type, "f32")) : [hasSM80, hasPTX70],
+
+    // b1 -> s32 @ m8n8k128(b1)
+    !and(!ne(op,"mma"),
+         !eq(geom,"m8n8k128")) : [hasSM75, hasPTX63],
+
+    // u4/s4 -> s32 @ m8n8k32 (u4/s4)
+    !and(!ne(op,"mma"),
+         !eq(geom,"m8n8k32")) : [hasSM75, hasPTX63],
+
+    !or(!eq(geom,"m16n8k8"),
+        !eq(geom,"m8n8k16")) : [hasSM75, hasPTX65],
 
-    !eq(geom, "m8n8k4") : [hasSM70, hasPTX64]);
+    !and(!ne(ptx_elt_type,"f64"),
+         !eq(geom, "m8n8k4")) : [hasSM70, hasPTX64],
+
+    // mma m8n8k32 requires a higher PTX version.
+    !and(!eq(op,"mma"),
+         !eq(geom,"m8n8k32")) : [hasSM75, hasPTX65],
+
+    !and(!eq(ptx_elt_type,"f64"),
+         !eq(geom, "m8n8k4")) : [hasSM80, hasPTX70],
+
+    !and(!eq(op,"mma"),
+         !or(!eq(geom, "m16n8k16"),
+             !eq(geom, "m16n8k4"),
+             !eq(geom, "m16n8k32"),
+             !eq(geom, "m16n8k64"),
+             !eq(geom, "m8n8k128"),
+             !eq(geom, "m16n8k128"),
+             !eq(geom, "m16n8k256"))) : [hasSM80, hasPTX70]);
 
   // template DAGs for instruction inputs/output.
   dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -7744,11 +7785,11 @@ defset list<WMMA_INSTR> MMA_LDSTs  = {
       foreach space = [".global", ".shared", ""] in {
         foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in {
           foreach frag = NVVM_MMA_OPS.all_ld_ops in
-            if NVVM_MMA_SUPPORTED<[frag], layout>.ret then
-              def : WMMA_LOAD<WMMA_REGINFO<frag>, layout, space, stride, addr>;
+            if NVVM_WMMA_LDST_SUPPORTED<frag, layout>.ret then
+              def : WMMA_LOAD<WMMA_REGINFO<frag, "load">, layout, space, stride, addr>;
           foreach frag = NVVM_MMA_OPS.all_st_ops in
-            if NVVM_MMA_SUPPORTED<[frag], layout>.ret then
-              def : WMMA_STORE_D<WMMA_REGINFO<frag>, layout, space, stride, addr>;
+            if NVVM_WMMA_LDST_SUPPORTED<frag, layout>.ret then
+              def : WMMA_STORE_D<WMMA_REGINFO<frag, "store">, layout, space, stride, addr>;
         } // addr
       } // space
     } // stride
@@ -7758,46 +7799,84 @@ defset list<WMMA_INSTR> MMA_LDSTs  = {
 // WMMA.MMA
 class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                WMMA_REGINFO FragC, WMMA_REGINFO FragD,
-               string ALayout, string BLayout, int Satfinite>
-  : WMMA_INSTR<WMMA_NAME_MMA<ALayout, BLayout, Satfinite, FragA, FragB, FragC, FragD>.record,
-                             [FragA.Ins, FragB.Ins, FragC.Ins]>,
+               string ALayout, string BLayout, int Satfinite, string rnd>
+  : WMMA_INSTR<WMMA_NAME<ALayout, BLayout, Satfinite, rnd, FragA, FragB, FragC, FragD>.record,
+                         [FragA.Ins, FragB.Ins, FragC.Ins]>,
     // Requires does not seem to have effect on Instruction w/o Patterns.
     // We set it here anyways and propagate to the Pat<> we construct below.
     Requires<FragA.Predicates> {
   let OutOperandList = FragD.Outs;
   let InOperandList  = !con(Args, (ins MmaCode:$ptx));
   string TypeList = !cond(
-    !eq(FragD.geom, "m8n8k4") : "." # FragD.ptx_elt_type
-                                # ".f16.f16."
-                                # FragC.ptx_elt_type,
-    !eq(FragD.ptx_elt_type, "s32") : ".s32"
-                                     # "." # FragA.ptx_elt_type
-                                     # "." # FragB.ptx_elt_type
-                                     # ".s32",
-    1: "." # FragD.ptx_elt_type # "." # FragC.ptx_elt_type,
+    !eq(FragA.ptx_elt_type, "f16") : "." # FragD.ptx_elt_type
+                                     # "." # FragC.ptx_elt_type,
+    1: "." # FragD.ptx_elt_type
+       # "." # FragA.ptx_elt_type
+       # "." # FragB.ptx_elt_type
+       # "." # FragC.ptx_elt_type,
   );
-  let AsmString = !if(!eq(FragA.geom, "m8n8k4"),
-     "mma.sync.aligned.m8n8k4"
-        # "." # ALayout
-        # "." # BLayout
-        # TypeList # "\n\t\t"
-        # FragD.regstring # ",\n\t\t"
-        # FragA.regstring # ",\n\t\t"
-        # FragB.regstring # ",\n\t\t"
-        # FragC.regstring # ";",
-     "wmma.mma"
-        # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "")
-        # ".sync"
-        # "${ptx:aligned}"
-        # "." # ALayout
-        # "." # BLayout
-        # "." # FragA.geom
-        # TypeList
-        # !if(Satfinite, ".satfinite", "") # "\n\t\t"
-        # FragD.regstring # ",\n\t\t"
-        # FragA.regstring # ",\n\t\t"
-        # FragB.regstring # ",\n\t\t"
-        # FragC.regstring # ";");
+  let AsmString = "wmma.mma"
+                  # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "")
+                  # ".sync"
+                  # "${ptx:aligned}"
+                  # "." # ALayout
+                  # "." # BLayout
+                  # "." # FragA.geom
+                  # !if(!ne(rnd, ""), !strconcat(".", rnd), "")
+                  # TypeList
+                  # !if(Satfinite, ".satfinite", "") # "\n\t\t"
+                  # FragD.regstring # ",\n\t\t"
+                  # FragA.regstring # ",\n\t\t"
+                  # FragB.regstring # ",\n\t\t"
+                  # FragC.regstring # ";";
+}
+
+defset list<WMMA_INSTR> WMMAs  = {
+  foreach layout_a = ["row", "col"] in {
+    foreach layout_b = ["row", "col"] in {
+      foreach satf = [0, 1] in {
+        foreach rnd = ["", "rn", "rz", "rm", "rp"] in {
+          foreach op = NVVM_MMA_OPS.all_wmma_ops in {
+            if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, satf, rnd>.ret then {
+              def : WMMA_MMA<WMMA_REGINFO<op[0], "wmma.mma">,
+                            WMMA_REGINFO<op[1], "wmma.mma">,
+                            WMMA_REGINFO<op[2], "wmma.mma">,
+                            WMMA_REGINFO<op[3], "wmma.mma">,
+                            layout_a, layout_b, satf, rnd>;
+            }
+          } // op
+        } // rnd
+      } // satf
+    } // layout_b
+  } // layout_a
+} // defset
+
+// MMA
+class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
+               WMMA_REGINFO FragC, WMMA_REGINFO FragD,
+               string ALayout, string BLayout, int Satfinite>
+  : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, FragA, FragB, FragC, FragD>.record,
+                        [FragA.Ins, FragB.Ins, FragC.Ins]>,
+    // Requires does not seem to have effect on Instruction w/o Patterns.
+    // We set it here anyways and propagate to the Pat<> we construct below.
+    Requires<FragA.Predicates> {
+  let OutOperandList = FragD.Outs;
+  let InOperandList  = !con(Args, (ins MmaCode:$ptx));
+  string TypeList = "." # FragD.ptx_elt_type
+                    # "." # FragA.ptx_elt_type
+                    # "." # FragB.ptx_elt_type
+                    # "." # FragC.ptx_elt_type;
+  let AsmString = "mma.sync.aligned."
+                  # FragA.geom
+                  # "." # ALayout
+                  # "." # BLayout
+                  # !if(Satfinite, ".satfinite", "")
+                  # TypeList
+                  # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "") # "\n\t\t"
+                  # FragD.regstring # ",\n\t\t"
+                  # FragA.regstring # ",\n\t\t"
+                  # FragB.regstring # ",\n\t\t"
+                  # FragC.regstring # ";";
 }
 
 defset list<WMMA_INSTR> MMAs  = {
@@ -7806,11 +7885,11 @@ defset list<WMMA_INSTR> MMAs  = {
       foreach satf = [0, 1] in {
         foreach op = NVVM_MMA_OPS.all_mma_ops in {
           if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
-            def : WMMA_MMA<WMMA_REGINFO<op[0]>,
-                           WMMA_REGINFO<op[1]>,
-                           WMMA_REGINFO<op[2]>,
-                           WMMA_REGINFO<op[3]>,
-                           layout_a, layout_b, satf>;
+            def : MMA<WMMA_REGINFO<op[0], "mma">,
+                      WMMA_REGINFO<op[1], "mma">,
+                      WMMA_REGINFO<op[2], "mma">,
+                      WMMA_REGINFO<op[3], "mma">,
+                      layout_a, layout_b, satf>;
           }
         } // op
       } // satf
@@ -7822,12 +7901,12 @@ defset list<WMMA_INSTR> MMAs  = {
 // Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a
 // dag, so the ptx.version must be appended *after* foreach replaces 'ins' with
 // the instruction record.
-class WMMA_PAT<WMMA_INSTR wi>
+class MMA_PAT<WMMA_INSTR wi>
       : Pat<wi.IntrinsicPattern,
             !con(!foreach(tmp, wi.Args, !subst(ins, wi, tmp)),
                  (wi ptx.version))>,
         Requires<wi.Predicates>;
 
 // Build intrinsic->instruction patterns for all MMA instructions.
-foreach mma = !listconcat(MMAs, MMA_LDSTs) in
-  def : WMMA_PAT<mma>;
+foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs) in
+  def : MMA_PAT<mma>;
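
As a concrete illustration of the AsmString built by the new MMA class above,
here is a small sketch (illustrative only; `${ptx:aligned}` is hard-coded to
".aligned", which is what it expands to on PTX 6.3 and newer):

# Illustrative only: assembles the same prefix the MMA class's AsmString
# produces, with the D.A.B.C type list and the optional modifiers.
def mma_asm_prefix(geom, alayout, blayout, satf, d, a, b, c):
    type_list = ".%s.%s.%s.%s" % (d, a, b, c)          # D.A.B.C
    return ("mma.sync.aligned." + geom
            + "." + alayout + "." + blayout
            + (".satfinite" if satf else "")
            + type_list
            + (".xor.popc" if a == "b1" else ""))

# e.g. one of the new integer MMA variants:
#   mma_asm_prefix("m16n8k16", "row", "col", True, "s32", "s8", "u8", "s32")
#   -> "mma.sync.aligned.m16n8k16.row.col.satfinite.s32.s8.u8.s32"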

diff  --git a/llvm/test/CodeGen/NVPTX/lit.local.cfg b/llvm/test/CodeGen/NVPTX/lit.local.cfg
index 2cb98eb371b21..8354800109ebd 100644
--- a/llvm/test/CodeGen/NVPTX/lit.local.cfg
+++ b/llvm/test/CodeGen/NVPTX/lit.local.cfg
@@ -1,2 +1,3 @@
 if not 'NVPTX' in config.root.targets:
     config.unsupported = True
+config.suffixes.add('.py')

diff  --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 8c140c4d93108..8a808bd377eb9 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -6,7 +6,7 @@
 # RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
 # RUN:           --check-prefixes=INTRINSICS,M16N16
 # RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
-# RUN:           --check-prefixes=INTRINSICS,NOEXTGEOM,NOINT,NOSUBINT,NOMMA
+# RUN:           --check-prefixes=INTRINSICS,NOEXTGEOM,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT
 # RUN: llc < %t-ptx60-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 \
 # RUN:           | FileCheck %t-ptx60-sm_70.ll
 
@@ -15,7 +15,7 @@
 # RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
 # RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM
 # RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
-# RUN:           --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NOMMA
+# RUN:           --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT
 # RUN: llc < %t-ptx61-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 \
 # RUN:           | FileCheck %t-ptx61-sm_70.ll
 
@@ -24,7 +24,7 @@
 # RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
 # RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT
 # RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
-# RUN:           --check-prefixes=INTRINSICS,NOSUBINT,NOMMA
+# RUN:           --check-prefixes=INTRINSICS,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT
 # RUN: llc < %t-ptx63-sm_72.ll -march=nvptx64 -mcpu=sm_72 -mattr=+ptx63 \
 # RUN:           | FileCheck %t-ptx63-sm_72.ll
 
@@ -33,7 +33,7 @@
 # RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
 # RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT
 # RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
-# RUN:           --check-prefixes=INTRINSICS,NOMMA
+# RUN:           --check-prefixes=INTRINSICS,NOMMA,NODOUBLE,NOALTFLOAT
 # RUN: llc < %t-ptx63-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx63 \
 # RUN:           | FileCheck %t-ptx63-sm_75.ll
 
@@ -42,10 +42,28 @@
 # RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \
 # RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,MMA
 # RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \
-# RUN:           --check-prefixes=INTRINSICS,NOINT,NOSUBINT
+# RUN:           --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NODOUBLE,NOALTFLOAT
 # RUN: llc < %t-ptx64-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx64 \
 # RUN:           | FileCheck %t-ptx64-sm_70.ll
 
+# Check all variants of instructions supported by PTX65 on SM75+
+# RUN: python %s --ptx=65 --gpu-arch=75 > %t-ptx65-sm_75.ll
+# RUN: FileCheck %t-ptx65-sm_75.ll < %t-ptx65-sm_75.ll \
+# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,PTX65MMA
+# RUN: FileCheck %t-ptx65-sm_75.ll < %t-ptx65-sm_75.ll \
+# RUN:           --check-prefixes=INTRINSICS
+# RUN: llc < %t-ptx65-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx65 \
+# RUN:           | FileCheck %t-ptx65-sm_75.ll
+
+# Check all variants of instructions supported by PTX70 on SM80+
+# RUN: python %s --ptx=70 --gpu-arch=80 > %t-ptx70-sm_80.ll
+# RUN: FileCheck %t-ptx70-sm_80.ll < %t-ptx70-sm_80.ll \
+# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,ALTFLOAT,DOUBLE,PTX65MMA,PTX70MMA
+# RUN: FileCheck %t-ptx70-sm_80.ll < %t-ptx70-sm_80.ll \
+# RUN:           --check-prefixes=INTRINSICS
+# RUN: llc < %t-ptx70-sm_80.ll -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 \
+# RUN:           | FileCheck %t-ptx70-sm_80.ll
+
 from __future__ import print_function
 
 import argparse
@@ -56,19 +74,23 @@ class MMAType:
   def __init__(self, ptx_type):
     self.ptx_type = ptx_type
     self.llvm_type = {
-        "f16" : "<2 x half>",
-        "f32" : "float",
-        "s32" : "i32",
-        "s8"  : "i32",
-        "u8"  : "i32",
-        "s4"  : "i32",
-        "u4"  : "i32",
-        "b1"  : "i32",
+        "f16"  : "<2 x half>",
+        "f32"  : "float",
+        "f64"  : "double",
+        "s32"  : "i32",
+        "s8"   : "i32",
+        "u8"   : "i32",
+        "s4"   : "i32",
+        "u4"   : "i32",
+        "b1"   : "i32",
+        "bf16" : "i32",
+        "tf32" : "i32",
     }[ptx_type];
 
     self.ptx_reg_pattern = {
         "f16" : "%hh[0-9]+",
         "f32" : "%f[0-9]+",
+        "f64" : "%fd[0-9]+",
     }.get(ptx_type, "%r[0-9]+")
 
   def __repr__(self):
@@ -78,16 +100,8 @@ class MMAFrag:
   def __init__(self, geom, frag, ptx_elt_type):
     self.geom = geom
     self.frag = frag
-    self.is_mma = True if geom == "m8n8k4" else False;
     self.mma_type = MMAType(ptx_elt_type);
     self.nregs = {
-        "a:f16" : 2 if self.is_mma else 8,
-        "b:f16" : 2 if self.is_mma else 8,
-        "c:f16" : 4,
-        "d:f16" : 4,
-        "c:f32" : 8,
-        "d:f32" : 8,
-    }.get("%s:%s" % (frag, ptx_elt_type), {
         # u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
         "m16n16k16:a:u8" : 2,
         "m16n16k16:a:s8" : 2,
@@ -110,18 +124,123 @@ def __init__(self, geom, frag, ptx_elt_type):
         "m32n8k16:c:s32" : 8,
         "m32n8k16:d:s32" : 8,
 
-        # u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1)
-        "m8n8k128:a:b1" : 1,
+        "m8n8k16:a:u8": 1,
+        "m8n8k16:a:s8": 1,
+        "m8n8k16:b:u8": 1,
+        "m8n8k16:b:s8": 1,
+        "m8n8k16:c:s32": 2,
+        "m8n8k16:d:s32": 2,
+
+        "m16n8k16:a:u8": 2,
+        "m16n8k16:a:s8": 2,
+        "m16n8k16:b:u8": 1,
+        "m16n8k16:b:s8": 1,
+        "m16n8k16:c:s32": 4,
+        "m16n8k16:d:s32": 4,
+
+        "m16n8k32:a:u8": 4,
+        "m16n8k32:a:s8": 4,
+        "m16n8k32:b:u8": 2,
+        "m16n8k32:b:s8": 2,
+        "m16n8k32:c:s32": 4,
+        "m16n8k32:d:s32": 4,
+
+        # u4/s4 -> s32 @ m8n8k32 (u4/s4)
         "m8n8k32:a:u4" : 1,
         "m8n8k32:a:s4" : 1,
-        "m8n8k128:b:b1" : 1,
         "m8n8k32:b:u4" : 1,
         "m8n8k32:b:s4" : 1,
-        "m8n8k128:c:s32" : 2,
-        "m8n8k128:d:s32" : 2,
         "m8n8k32:c:s32" : 2,
         "m8n8k32:d:s32" : 2,
-    }.get("%s:%s:%s" % (geom, frag, ptx_elt_type), None));
+
+        "m16n8k32:a:u4" : 2,
+        "m16n8k32:a:s4" : 2,
+        "m16n8k32:b:u4" : 1,
+        "m16n8k32:b:s4" : 1,
+        "m16n8k32:c:s32" : 4,
+        "m16n8k32:d:s32" : 4,
+
+        "m16n8k64:a:u4" : 4,
+        "m16n8k64:a:s4" : 4,
+        "m16n8k64:b:u4" : 2,
+        "m16n8k64:b:s4" : 2,
+        "m16n8k64:c:s32" : 4,
+        "m16n8k64:d:s32" : 4,
+
+        # b1 -> s32 @ m8n8k128(b1)
+        "m8n8k128:a:b1" : 1,
+        "m8n8k128:b:b1" : 1,
+        "m8n8k128:c:s32" : 2,
+        "m8n8k128:d:s32" : 2,
+
+        "m16n8k128:a:b1" : 2,
+        "m16n8k128:b:b1" : 1,
+        "m16n8k128:c:s32" : 4,
+        "m16n8k128:d:s32" : 4,
+
+        "m16n8k256:a:b1" : 4,
+        "m16n8k256:b:b1" : 2,
+        "m16n8k256:c:s32" : 4,
+        "m16n8k256:d:s32" : 4,
+
+        # bf16 -> f32 @ m16n16k16/m8n32k16/m32n8k16
+        "m16n16k16:a:bf16" : 4,
+        "m16n16k16:b:bf16" : 4,
+        "m8n32k16:a:bf16" : 2,
+        "m8n32k16:b:bf16" : 8,
+        "m32n8k16:a:bf16" : 8,
+        "m32n8k16:b:bf16" : 2,
+
+        "m16n8k16:a:bf16" : 4,
+        "m16n8k16:b:bf16" : 2,
+        "m16n8k16:c:f32" : 4,
+        "m16n8k16:d:f32" : 4,
+        "m16n8k8:a:bf16" : 2,
+        "m16n8k8:b:bf16" : 1,
+        "m16n8k8:c:f32" : 4,
+        "m16n8k8:d:f32" : 4,
+
+        "m8n8k4:a:f64" : 1,
+        "m8n8k4:b:f64" : 1,
+        "m8n8k4:c:f64" : 2,
+        "m8n8k4:d:f64" : 2,
+
+        # tf32 -> f32 @ m16n16k8
+        "m16n16k8:a:tf32" : 4,
+        "m16n16k8:b:tf32" : 4,
+
+        "m16n8k4:a:tf32" : 2,
+        "m16n8k4:b:tf32" : 1,
+        "m16n8k4:c:f32" : 4,
+        "m16n8k4:d:f32" : 4,
+        "m16n8k8:a:tf32" : 4,
+        "m16n8k8:b:tf32" : 2,
+        "m16n8k8:c:f32" : 4,
+        "m16n8k8:d:f32" : 4,
+
+        "m8n8k4:a:f16": 2,
+        "m8n8k4:b:f16": 2,
+        "m16n8k8:a:f16": 2,
+        "m16n8k8:b:f16": 1,
+        "m16n8k8:c:f16": 2,
+        "m16n8k8:d:f16": 2,
+        "m16n8k8:c:f32": 4,
+        "m16n8k8:d:f32": 4,
+        "m16n8k16:a:f16": 4,
+        "m16n8k16:b:f16": 2,
+        "m16n8k16:c:f16": 2,
+        "m16n8k16:d:f16": 2,
+        "m16n8k16:c:f32": 4,
+        "m16n8k16:d:f32": 4,
+    }.get("%s:%s:%s" % (geom, frag, ptx_elt_type), {
+        # All other FP shape/fragment/type combinations have the same size
+        "a:f16" : 8,
+        "b:f16" : 8,
+        "c:f16" : 4,
+        "d:f16" : 4,
+        "c:f32" : 8,
+        "d:f32" : 8,
+    }.get("%s:%s" % (frag, ptx_elt_type), None))
     assert(self.nregs);
 
   def __repr__(self):
@@ -153,9 +272,13 @@ def make_ldst_ops(geoms, frags, types):
   return [MMAFrag(geom, frag, ptx_type) for (geom, frag, ptx_type)
           in product(geoms, frags, types)]
 
-def get_mma_ops():
-  return (make_mma_ops(["m8n8k4"],
-                       ["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
+def get_wmma_ops():
+  return (make_mma_ops(["m16n16k8"],
+                       ["tf32"], [], ["f32"], []) +
+          make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
+                       ["bf16"], [], ["f32"], []) +
+          make_mma_ops(["m8n8k4"],
+                       ["f64"], [], ["f64"], []) +
           make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                        ["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
           make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
@@ -164,20 +287,38 @@ def get_mma_ops():
                        ["s4", "u4"], [], ["s32"], []) +
           make_mma_ops(["m8n8k128"],
                        ["b1"], [], ["s32"], []))
+
+def get_mma_ops():
+  return (make_mma_ops(["m8n8k4"],
+                       ["f64"], [], ["f64"], []) +
+          make_mma_ops(["m16n8k4", "m16n8k8"],
+                       ["tf32"], [], ["f32"], []) +
+          make_mma_ops(["m16n8k16", "m16n8k8"],
+                       ["bf16"], [], ["f32"], []) +
+          make_mma_ops(["m8n8k4", "m16n8k8", "m16n8k16"],
+                       ["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
+          make_mma_ops(["m8n8k16", "m16n8k16", "m16n8k32"],
+                       ["s8", "u8"], ["s8", "u8"], ["s32"], []) +
+          make_mma_ops(["m8n8k32", "m16n8k32", "m16n8k64"],
+                       ["s4", "u4"], ["s4", "u4"], ["s32"], []) +
+          make_mma_ops(["m8n8k128", "m16n8k128", "m16n8k256"],
+                       ["b1"], [], ["s32"], []))
+
 def get_ldst_ops(kind):
   ldst_ops = (make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
-                            ["a", "b"], ["f16", "u8", "s8"]) +
+                            ["a", "b"], ["f16", "u8", "s8", "bf16"]) +
               make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                             ["c", "d"], ["f16", "f32", "s32"]) +
               make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4","u4"]) +
               make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"]) +
-              make_ldst_ops(["m8n8k32", "m8n8k128"],  ["c", "d"], ["s32"]))
+              make_ldst_ops(["m8n8k32", "m8n8k128"],  ["c", "d"], ["s32"]) +
+              make_ldst_ops(["m8n8k4"], ["a", "b", "c", "d"], ["f64"]) +
+              make_ldst_ops(["m16n16k8"], ["a", "b"], ["tf32"]) +
+              make_ldst_ops(["m16n16k8"], ["c", "d"], ["f32"]))
   return [ x for x in ldst_ops if (x.frag == "d") == (kind == "store")]
 
-def is_geom_supported(geom):
+def is_wmma_geom_supported(geom):
   # geometries for FP and ints.
-  if geom == "m8n8k4":
-    return ptx_version >= 64
   if geom in ["m8n32k16", "m32n8k16"]:
     return ptx_version >= 61
   # geometries for sub-ints.
@@ -185,6 +326,21 @@ def is_geom_supported(geom):
     return ptx_version >= 63 and gpu_arch >= 75
   if geom == "m16n16k16":
     return ptx_version >= 60
+  if geom == "m16n8k8":
+    return ptx_version >= 65
+  if geom in ["m16n16k8", "m8n8k4"]:
+    return ptx_version >= 70
+  assert(False) # Unexpected geometry.
+
+def is_mma_geom_supported(geom):
+  # geometries for FP and ints.
+  if geom == "m8n8k4":
+    return ptx_version >= 64
+  if geom in ["m16n8k8", "m8n8k16", "m8n8k32"]:
+    return ptx_version >= 65
+  if geom in ["m16n8k16", "m16n8k4", "m16n8k32", "m16n8k64", "m8n8k128",
+              "m16n8k128", "m16n8k256"]:
+    return ptx_version >= 70
   assert(False) # Unexpected geometry.
 
 def is_type_supported(ptx_type):
@@ -192,30 +348,63 @@ def is_type_supported(ptx_type):
     return ptx_version >= 63 and gpu_arch >= 72
   if ptx_type in ["s4", "u4", "b1"]:
     return ptx_version >= 63 and gpu_arch >= 75
+  if ptx_type in ["bf16", "tf32", "f64"]:
+    return ptx_version >= 70
   return ptx_version >= 60 and gpu_arch >= 70
 
+def is_wmma_variant_supported(op, layout_a, layout_b, rnd, satf):
+  if not (is_type_supported(op.a.mma_type.ptx_type)
+          and is_wmma_geom_supported(op.a.geom)):
+    return False
+
+  # rnd is only supported for FP64 WMMA
+  if rnd and op.a.mma_type.ptx_type != "f64":
+    return False
+
+  if satf:
+    # satfinite for floating-point types was removed in PTX 6.5.
+    if op.a.mma_type.ptx_type == "f16" and ptx_version >= 65:
+      return False
+    if not op.a.mma_type.ptx_type in ["f16", "s8", "u8", "s4", "u4"]:
+      return False
+
+  # Sub-integer types require row/col layout.
+  if op.a.mma_type.ptx_type in ["s4", "u4", "b1"]:
+    return layout_a == "row" and layout_b == "col"
+  return True
 
 def is_mma_variant_supported(op, layout_a, layout_b, satf):
   if not (is_type_supported(op.a.mma_type.ptx_type)
-          and is_geom_supported(op.a.geom)):
+          and is_mma_geom_supported(op.a.geom)):
+    return False
+
+  if satf and not op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4"]:
+    return False
+
+  # If the type of C is f32 then the type of D must be f32 as well.
+  if (op.a.geom == "m8n8k4" and op.c.mma_type.ptx_type == "f32"
+      and op.d.mma_type.ptx_type != "f32"):
     return False
-  if op.a.geom == "m8n8k4":
-    if satf:
+
+  # The A and B types must be the same, and the C and D types must be the same.
+  if (op.a.geom == "m16n8k8"
+        and (op.a.mma_type.ptx_type != op.b.mma_type.ptx_type
+             or op.c.mma_type.ptx_type != op.d.mma_type.ptx_type)):
       return False
-    if op.c.mma_type.ptx_type == "f32":
-      # If C is f32, D must be, too.
-      return op.d.mma_type.ptx_type == "f32"
 
-  # sub-integer require row/col layout, and no satf.
-  if op.a.mma_type.ptx_type in ["s4", "u4", "b1"]:
-    if op.a.mma_type.ptx_type == "b1" and satf:
+  # The C and D types must be the same.
+  if (op.a.geom == "m16n8k16"
+      and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type):
       return False
+
+  # Require row/col layout for all MMA variants except m8n8k4 with f16.
+  if not (op.a.geom == "m8n8k4" and op.a.mma_type.ptx_type == "f16"):
     return layout_a == "row" and layout_b == "col"
   return True
 
 def is_ldst_variant_supported(frag, layout):
   if not (is_type_supported(frag.mma_type.ptx_type)
-          and is_geom_supported(frag.geom)):
+          and is_wmma_geom_supported(frag.geom)):
     return False
   if frag.mma_type.ptx_type in ["s4", "u4", "b1"]:
     # sub-integer require sm_75 and ptx63, row/col layout for a/b.
@@ -396,24 +585,37 @@ def gen_wmma_store_tests():
   return generated_items
 
 def mma_signature(op):
-  if op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4", "b1"]:
-    # int and sub-int ops are identified by input type.
-    return op.a.mma_type.ptx_type
-  else:
-    # the rest are FP ops identified by accumulator & result type.
+  if op.a.mma_type.ptx_type == "f16":
+    # FP16 ops identified by accumulator & result type.
     return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
+  elif op.a.mma_type.ptx_type != op.b.mma_type.ptx_type:
+    # other ops are identified by input types.
+    return "%s.%s" % (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type)
+  else:
+    # if input types are the same, it only appears once.
+    return op.a.mma_type.ptx_type
 
 def mma_ptx_signature(op):
-  if op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4", "b1"]:
-    # int and sub-int instructions encode all four types as D.A.B.C
-    return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))
-  if op.a.geom == "m8n8k4":
-    return "%s.f16.f16.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
+  # Encode all four types as D.A.B.C
+  return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))
+
+def wmma_signature(op):
+  if op.a.mma_type.ptx_type == "f16":
+    # FP16 ops identified by accumulator & result type.
+    return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
   else:
-    # the rest are FP instructions use D.C
+    # other ops are identified by input type.
+    return op.a.mma_type.ptx_type
+
+def wmma_ptx_signature(op):
+  if op.a.mma_type.ptx_type == "f16":
+    # FP16 instructions use D.C
     return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
+  else:
+    # other instructions encode all four types as D.A.B.C
+    return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))
 
-def gen_wmma_mma_tests():
+def common_mma_test_gen(params, op, intrinsic_template, instruction_template):
   mma_template = """
 declare ${ret_ty} @${intrinsic}(
         ${args});
@@ -431,10 +633,61 @@ def gen_wmma_mma_tests():
   ret ${ret_ty} %r;
 }
 """
-  wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${intrinsic_signature}${satf}"
-  wmma_instruction_template = "wmma.mma${mma_variant}.sync${aligned}.${alayout}.${blayout}.${geom}.${ptx_signature}${satf}"
-  mma_intrinsic_template = "llvm.nvvm.mma.${geom}.${alayout}.${blayout}.${intrinsic_signature}"
-  mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}.${ptx_signature}"
+
+  test_params = params
+  test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+  test_params["function"] = test_params["intrinsic"].replace(".", "_")
+  test_params["instruction"] = Template(instruction_template).substitute(params)
+  test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
+  test_params["check_a"] = check_pattern(op.a)
+  test_params["check_b"] = check_pattern(op.b)
+  test_params["check_c"] = check_pattern(op.c)
+  test_params["check_d"] = check_pattern(op.d)
+  args = ",\n        ".join(make_wmma_slice_args(frag)
+                            for frag in (op.a, op.b, op.c))
+  test_params["args"] = args
+  print(Template(mma_template).substitute(test_params))
+  return (test_params["intrinsic"], test_params["instruction"])
+
+def gen_wmma_mma_tests():
+  wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}${rnd}.${intrinsic_signature}${satf}"
+  wmma_instruction_template = "wmma.mma${mma_variant}.sync${aligned}.${alayout}.${blayout}.${geom}${rnd}.${ptx_signature}${satf}"
+
+  generated_items=[]
+
+  for op, alayout, blayout, rnd, satf in product(
+      get_wmma_ops(),
+      ["row","col"],
+      ["row","col"],
+      [".rn", ".rz", ".rm", ".rp", ""],
+      [".satfinite", ""]):
+
+    if not is_wmma_variant_supported(op, alayout, blayout, rnd, satf):
+      continue
+
+    params = {
+        "aligned" : ".aligned" if ptx_version >= 63 else "",
+        "alayout" : alayout,
+        "blayout" : blayout,
+        "intrinsic_signature" : wmma_signature(op),
+        "ptx_signature" : wmma_ptx_signature(op),
+        "satf"  : satf,
+        "rnd"   : rnd,
+        "geom"  : op.a.geom,
+        "mma_variant" : ".xor.popc" if op.a.mma_type.ptx_type == "b1" else "",
+    }
+
+    intrinsic_template = wmma_intrinsic_template
+    instruction_template = wmma_instruction_template
+
+    generated_items.append(common_mma_test_gen(params, op,
+      intrinsic_template, instruction_template))
+
+  return generated_items
+
+def gen_mma_tests():
+  mma_intrinsic_template = "llvm.nvvm.mma.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}"
+  mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${mma_variant}"
 
   generated_items=[]
 
@@ -458,28 +711,11 @@ def gen_wmma_mma_tests():
         "mma_variant" : ".xor.popc" if op.a.mma_type.ptx_type == "b1" else "",
     }
 
-    if op.a.geom == "m8n8k4":
-      intrinsic_template = mma_intrinsic_template
-      instruction_template = mma_instruction_template
-    else:
-      intrinsic_template = wmma_intrinsic_template
-      instruction_template = wmma_instruction_template
+    intrinsic_template = mma_intrinsic_template
+    instruction_template = mma_instruction_template
 
-    test_params = params
-    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
-    test_params["function"] = test_params["intrinsic"].replace(".", "_")
-    test_params["instruction"] = Template(instruction_template).substitute(params)
-    test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
-    test_params["check_a"] = check_pattern(op.a)
-    test_params["check_b"] = check_pattern(op.b)
-    test_params["check_c"] = check_pattern(op.c)
-    test_params["check_d"] = check_pattern(op.d)
-    args = ",\n        ".join(make_wmma_slice_args(frag)
-                              for frag in (op.a, op.b, op.c))
-    test_params["args"] = args
-    print(Template(mma_template).substitute(test_params))
-    generated_items.append((test_params["intrinsic"],
-                            test_params["instruction"]))
+    generated_items.append(common_mma_test_gen(params, op,
+      intrinsic_template, instruction_template))
 
   return generated_items
 
@@ -497,6 +733,8 @@ def gen_check_unsupported_ops(items):
 ; NOINT-NOT: .{{s32|s8}}
 ; NOSUBINT-NOT: {{s4|u4|b1}}
 ; NOMMA-NOT: .m8n8k4.
+; NOALTFLOAT-NOT: .{{bf16|tf32}}
+; NODOUBLE-NOT: .f64
 
 ; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
 ; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
@@ -543,10 +781,61 @@ def gen_check_unsupported_ops(items):
 ; SUBINT-DAG: m8n8k32.mma.{{.*}}.s4
 ; SUBINT-DAG: m8n8k128.mma.{{.*}}.b1
 
+; ALTFLOAT-DAG: m16n16k16.load.{{[ab].*}}.bf16.p
+; ALTFLOAT-DAG: m8n32k16.load.{{[ab].*}}.bf16.p
+; ALTFLOAT-DAG: m32n8k16.load.{{[ab].*}}.bf16.p
+; ALTFLOAT-DAG: m16n16k8.load.{{[ab].*}}.tf32.p
+; ALTFLOAT-DAG: m16n16k16.mma.{{.*}}.bf16
+; ALTFLOAT-DAG: m8n32k16.mma.{{.*}}.bf16
+; ALTFLOAT-DAG: m32n8k16.mma.{{.*}}.bf16
+; ALTFLOAT-DAG: m16n16k8.mma.{{.*}}.tf32
+
+; DOUBLE-DAG: m8n8k4.load.{{[abc].*}}.f64.p
+; DOUBLE-DAG: m8n8k4.store.d.{{.*}}.f64.p
+; DOUBLE-DAG: m8n8k4.mma.{{.*}}.f64
+
 ; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f32
 ; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f16
 ; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f16
 ; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f32
+
+; PTX65MMA-DAG: mma.m16n8k8.row.col.f16.f16
+; PTX65MMA-DAG: mma.m16n8k8.row.col.f32.f32
+; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.u8
+; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.s8
+; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.u8
+; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.s8
+; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.u4
+; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.s4
+; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.u4
+; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.s4
+
+; PTX70MMA-DAG: mma.m8n8k4.row.col.f64
+; PTX70MMA-DAG: mma.m16n8k4.row.col.tf32
+; PTX70MMA-DAG: mma.m16n8k8.row.col.tf32
+; PTX70MMA-DAG: mma.m16n8k16.row.col.bf16
+; PTX70MMA-DAG: mma.m16n8k8.row.col.bf16
+; PTX70MMA-DAG: mma.m16n8k16.row.col.f16.f16
+; PTX70MMA-DAG: mma.m16n8k16.row.col.f32.f32
+; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.u8
+; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.s8
+; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.u8
+; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.s8
+; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.u8
+; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.s8
+; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.u8
+; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.s8
+; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.u4
+; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.s4
+; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.u4
+; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.s4
+; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.u4
+; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.s4
+; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.u4
+; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.s4
+; PTX70MMA-DAG: mma.m8n8k128.row.col.b1
+; PTX70MMA-DAG: mma.m16n8k128.row.col.b1
+; PTX70MMA-DAG: mma.m16n8k256.row.col.b1
 ;
 
 """)
@@ -561,6 +850,7 @@ def gen_tests():
   items = gen_wmma_load_tests()
   items += gen_wmma_store_tests()
   items += gen_wmma_mma_tests()
+  items += gen_mma_tests()
   gen_check_unsupported_ops(items)
 
 parser = argparse.ArgumentParser()
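
For reference, this is the kind of (intrinsic, instruction) pair the updated
generator emits for one of the new f64 WMMA variants with an explicit rounding
mode, derived from wmma_intrinsic_template / wmma_instruction_template above
(shown only as an illustration):

# Illustrative only: one pair gen_wmma_mma_tests() would append to
# generated_items for op=m8n8k4/f64, row/row layouts, rnd=".rn", no satf.
intrinsic = "llvm.nvvm.wmma.m8n8k4.mma.row.row.rn.f64"
instruction = "wmma.mma.sync.aligned.row.row.m8n8k4.rn.f64.f64.f64.f64"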


        

