[llvm] d774b4a - [NVPTX, CUDA] Add .and.popc variant of the b1 MMA instruction.

Artem Belevich via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 15 12:02:53 PDT 2021


Author: Artem Belevich
Date: 2021-07-15T12:02:09-07:00
New Revision: d774b4aa5eac785ffe40009091667521e183df40

URL: https://github.com/llvm/llvm-project/commit/d774b4aa5eac785ffe40009091667521e183df40
DIFF: https://github.com/llvm/llvm-project/commit/d774b4aa5eac785ffe40009091667521e183df40.diff

LOG: [NVPTX, CUDA] Add .and.popc variant of the b1 MMA instruction.

That should allow clang to compile mma.h from CUDA-11.3.

Differential Revision: https://reviews.llvm.org/D105384

Added: 
    

Modified: 
    clang/include/clang/Basic/BuiltinsNVPTX.def
    clang/lib/CodeGen/CGBuiltin.cpp
    clang/test/CodeGen/builtins-nvptx-mma.cu
    clang/test/CodeGen/builtins-nvptx-mma.py
    llvm/include/llvm/IR/IntrinsicsNVVM.td
    llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
    llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
    llvm/test/CodeGen/NVPTX/wmma.py

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/Basic/BuiltinsNVPTX.def b/clang/include/clang/Basic/BuiltinsNVPTX.def
index e815138a15c15..3c96900136a40 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.def
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.def
@@ -724,6 +724,7 @@ TARGET_BUILTIN(__hmma_m8n32k16_mma_f16f32, "vi*iC*iC*fC*IiIi", "", AND(SM_70,PTX
 TARGET_BUILTIN(__bmma_m8n8k128_ld_a_b1, "vi*iC*UiIi", "", AND(SM_75,PTX63))
 TARGET_BUILTIN(__bmma_m8n8k128_ld_b_b1, "vi*iC*UiIi", "", AND(SM_75,PTX63))
 TARGET_BUILTIN(__bmma_m8n8k128_ld_c, "vi*iC*UiIi", "", AND(SM_75,PTX63))
+TARGET_BUILTIN(__bmma_m8n8k128_mma_and_popc_b1, "vi*iC*iC*iC*Ii", "", AND(SM_75,PTX71))
 TARGET_BUILTIN(__bmma_m8n8k128_mma_xor_popc_b1, "vi*iC*iC*iC*Ii", "", AND(SM_75,PTX63))
 TARGET_BUILTIN(__bmma_m8n8k128_st_c_i32, "vi*iC*UiIi", "", AND(SM_75,PTX63))
 TARGET_BUILTIN(__imma_m16n16k16_ld_a_s8, "vi*iC*UiIi", "", AND(SM_72,PTX63))

diff  --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 0635be425e0aa..4091b6cc62ce9 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -16630,9 +16630,18 @@ static NVPTXMmaInfo getNVPTXMmaInfo(unsigned BuiltinID) {
       0, \
       0
 // b1 MMA does not support .satfinite.
-#define MMA_VARIANTS_B1(geom, type) \
+#define MMA_VARIANTS_B1_XOR(geom, type) \
       0, \
-      Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type,             \
+      Intrinsic::nvvm_wmma_##geom##_mma_xor_popc_row_col_##type,             \
+      0, \
+      0, \
+      0, \
+      0, \
+      0, \
+      0
+#define MMA_VARIANTS_B1_AND(geom, type) \
+      0, \
+      Intrinsic::nvvm_wmma_##geom##_mma_and_popc_row_col_##type,             \
       0, \
       0, \
       0, \
@@ -16689,7 +16698,9 @@ static NVPTXMmaInfo getNVPTXMmaInfo(unsigned BuiltinID) {
   case NVPTX::BI__imma_m8n8k32_mma_u4:
     return {1, 1, 2, 2, {{MMA_VARIANTS_I4(m8n8k32, u4)}}};
   case NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1:
-    return {1, 1, 2, 2, {{MMA_VARIANTS_B1(m8n8k128, b1)}}};
+    return {1, 1, 2, 2, {{MMA_VARIANTS_B1_XOR(m8n8k128, b1)}}};
+  case NVPTX::BI__bmma_m8n8k128_mma_and_popc_b1:
+    return {1, 1, 2, 2, {{MMA_VARIANTS_B1_AND(m8n8k128, b1)}}};
 
   // Double MMA
   case NVPTX::BI__dmma_m8n8k4_mma_f64:
@@ -16710,7 +16721,8 @@ static NVPTXMmaInfo getNVPTXMmaInfo(unsigned BuiltinID) {
 #undef MMA_VARIANTS
 #undef MMA_SATF_VARIANTS
 #undef MMA_VARIANTS_I4
-#undef MMA_VARIANTS_B1
+#undef MMA_VARIANTS_B1_AND
+#undef MMA_VARIANTS_B1_XOR
 }
 
 } // namespace
@@ -17119,6 +17131,7 @@ CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, const CallExpr *E) {
   case NVPTX::BI__imma_m8n8k32_mma_s4:
   case NVPTX::BI__imma_m8n8k32_mma_u4:
   case NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1:
+  case NVPTX::BI__bmma_m8n8k128_mma_and_popc_b1:
   case NVPTX::BI__dmma_m8n8k4_mma_f64:
   case NVPTX::BI__mma_bf16_m16n16k16_mma_f32:
   case NVPTX::BI__mma_bf16_m8n32k16_mma_f32:
@@ -17136,7 +17149,8 @@ CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, const CallExpr *E) {
     if (Layout < 0 || Layout > 3)
       return nullptr;
     llvm::APSInt SatfArg;
-    if (BuiltinID == NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1)
+    if (BuiltinID == NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1 ||
+        BuiltinID == NVPTX::BI__bmma_m8n8k128_mma_and_popc_b1)
       SatfArg = 0;  // .b1 does not have satf argument.
     else if (Optional<llvm::APSInt> OptSatfArg =
                  E->getArg(5)->getIntegerConstantExpr(getContext()))

diff  --git a/clang/test/CodeGen/builtins-nvptx-mma.cu b/clang/test/CodeGen/builtins-nvptx-mma.cu
index 7e9bac86792d2..aaa44bcaa7e22 100644
--- a/clang/test/CodeGen/builtins-nvptx-mma.cu
+++ b/clang/test/CodeGen/builtins-nvptx-mma.cu
@@ -3,20 +3,21 @@
 // *** DO NOT EDIT ***
 //
 //  This test has been automatically generated by
-//  builtins-nvtx-mma.py --ptx=70 --gpu-arch=80
+//  builtins-nvtx-mma.py --ptx=71 --gpu-arch=80
 //
-// Make sure we can handle all builtins available on sm_80 with PTX70
+// Make sure we can handle all builtins available on sm_80 with PTX71
 // RUN: %clang_cc1 -triple nvptx64-unknown-unknown -target-cpu sm_80 \
-// RUN:            -fcuda-is-device -target-feature +ptx70 \
-// RUN:            -DPTX=70 -DSM=80 \
+// RUN:            -fcuda-is-device -target-feature +ptx71 \
+// RUN:            -DPTX=71 -DSM=80 \
 // RUN:            -S -emit-llvm -o - -x cuda %s \
-// RUN:   | FileCheck -check-prefixes=CHECK_PTX70_SM80,CHECK_PTX60_SM70,CHECK_PTX63_SM72,CHECK_PTX61_SM70,CHECK_PTX63_SM75 %s
+// RUN:   | FileCheck -check-prefixes=CHECK_PTX70_SM80,CHECK_PTX60_SM70,CHECK_PTX63_SM72,CHECK_PTX61_SM70,CHECK_PTX63_SM75,CHECK_PTX71_SM75 %s
 // Verify that all builtins have correct constraints.
 // RUN: %clang_cc1 -triple nvptx-unknown-unknown \
 // RUN:   -target-cpu sm_60 -target-feature +ptx42 \
-// RUN:   -DPTX=70 -DSM=80 -fcuda-is-device -S -o /dev/null -x cuda \
+// RUN:   -DPTX=71 -DSM=80 -fcuda-is-device -S -o /dev/null -x cuda \
 // RUN:   -verify %s
 
+
 #if !defined(CUDA_VERSION)
 #define __device__ __attribute__((device))
 #define __global__ __attribute__((global))
@@ -31,6 +32,7 @@ __device__ void test_wmma_buitins(int *src, int *dst,
                                   float *fsrc, float *fdst,
                                   double *dsrc, double *ddst, int ldm) {
 
+
 #if (PTX >= 60) && (SM >= 70)
 
   // CHECK_PTX60_SM70: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.f16
@@ -735,7 +737,7 @@ __device__ void test_wmma_buitins(int *src, int *dst,
   // CHECK_PTX63_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k32.store.d.row.stride.s32
   // expected-error-re at +1 {{'__imma_m8n8k32_st_c_i32' needs target feature (sm_75{{.*}},(ptx63{{.*}}}}
   __imma_m8n8k32_st_c_i32(dst, src, ldm, 0);
-  // CHECK_PTX63_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k128.mma.row.col.b1
+  // CHECK_PTX63_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k128.mma.xor.popc.row.col.b1
   // expected-error-re at +1 {{'__bmma_m8n8k128_mma_xor_popc_b1' needs target feature (sm_75{{.*}},(ptx63{{.*}}}}
   __bmma_m8n8k128_mma_xor_popc_b1(dst, src, src, src, 1);
   // CHECK_PTX63_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k32.mma.row.col.s4
@@ -750,7 +752,7 @@ __device__ void test_wmma_buitins(int *src, int *dst,
   // CHECK_PTX63_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k32.mma.row.col.u4.satfinite
   // expected-error-re at +1 {{'__imma_m8n8k32_mma_u4' needs target feature (sm_75{{.*}},(ptx63{{.*}}}}
   __imma_m8n8k32_mma_u4(dst, src, src, src, 1, 1);
-#endif // (PTX >= 63) && (SM >= 75)
+#endif // (PTX >= 63) && (SM >= 75) 
 
 #if (PTX >= 70) && (SM >= 80)
 
@@ -898,5 +900,12 @@ __device__ void test_wmma_buitins(int *src, int *dst,
   // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64
   // expected-error-re at +1 {{'__dmma_m8n8k4_mma_f64' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
   __dmma_m8n8k4_mma_f64(ddst, dsrc, dsrc, dsrc, 0, 0);
-#endif // (PTX >= 70) && (SM >= 80)
+#endif // (PTX >= 70) && (SM >= 80) 
+
+#if (PTX >= 71) && (SM >= 75)
+
+  // CHECK_PTX71_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k128.mma.and.popc.row.col.b1
+  // expected-error-re at +1 {{'__bmma_m8n8k128_mma_and_popc_b1' needs target feature (sm_75{{.*}},(ptx71{{.*}}}}
+  __bmma_m8n8k128_mma_and_popc_b1(dst, src, src, src, 1);
+#endif // (PTX >= 71) && (SM >= 75) 
 }

diff  --git a/clang/test/CodeGen/builtins-nvptx-mma.py b/clang/test/CodeGen/builtins-nvptx-mma.py
index 2ffc21b12fb06..dc40f04c11ce6 100644
--- a/clang/test/CodeGen/builtins-nvptx-mma.py
+++ b/clang/test/CodeGen/builtins-nvptx-mma.py
@@ -22,24 +22,29 @@ def __repr__(self):
     return "%s:%s:%s" % (self.geom, self.frag, self.ptx_type)
 
 class MMAOp:
-  def __init__(self, a, b, c, d):
+  def __init__(self, a, b, c, d, b1op=""):
     self.a = a
     self.b = b
     self.c = c
     self.d = d
+    self.b1op = b1op
 
   def __repr__(self):
     return ("{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d ))
 
-def make_mma_ops(geoms, types_a, types_b, types_c, types_d):
+def make_mma_ops(geoms, types_a, types_b, types_c, types_d, b1ops=None):
   ops = []
+  if b1ops is None:
+    b1ops = [""]
   for geom, type_a, type_c in product( geoms,  types_a, types_c):
     for type_b, type_d in product(types_b if types_b else [type_a],
                                   types_d if types_d else [type_c]):
-      ops.append(MMAOp(MMAFrag(geom, "a", type_a),
-                       MMAFrag(geom, "b", type_b),
-                       MMAFrag(geom, "c", type_c),
-                       MMAFrag(geom, "d", type_d)))
+      ops += [
+          MMAOp(MMAFrag(geom, "a", type_a),
+                MMAFrag(geom, "b", type_b),
+                MMAFrag(geom, "c", type_c),
+                MMAFrag(geom, "d", type_d), b1op)
+          for b1op in b1ops]
   return ops
 
 def make_ldst_ops(geoms, frags, types):
@@ -60,9 +65,12 @@ def get_mma_ops():
           make_mma_ops(["m8n8k32"],
                        ["s4", "u4"], [], ["s32"], []) +
           make_mma_ops(["m8n8k128"],
-                       ["b1"], [], ["s32"], []))
+                       ["b1"], [], ["s32"], [],
+                       [".xor.popc", ".and.popc"]))
 
 def get_ldst_ops():
+  # NOTE: fragments are from the point of view of PTX.
+  # fragment `d` is only for store ops, others for both loads and stores.
   return (make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                         ["a", "b"], ["f16", "u8", "s8", "bf16"]) +
           make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
@@ -71,8 +79,11 @@ def get_ldst_ops():
           make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"]) +
           make_ldst_ops(["m8n8k32", "m8n8k128"],  ["c", "d"], ["s32"]) +
           make_ldst_ops(["m8n8k4"], ["a", "b", "c", "d"], ["f64"]) +
-          make_ldst_ops(["m16n16k8"], ["a", "b"], ["tf32"]) +
-          make_ldst_ops(["m16n16k8"], ["c", "d"], ["f32"]))
+          # TF32 m16n16k8 is odd.
+          # For fragment 'C' it uses __mma_*tf32*_m16n16k8_ld_c
+          # but 'D' calls __mma_m16n16k8_st_c_*f32*.
+          make_ldst_ops(["m16n16k8"], ["a", "b", "c"], ["tf32"]) +
+          make_ldst_ops(["m16n16k8"], ["d"], ["f32"]))
 
 def is_geom_supported(geom):
   # geometries for FP and ints.
@@ -180,15 +191,19 @@ def get_mma_builtin_name(op):
   else:
     suffix = op.a.ptx_type
 
-  name = "%s_%s_mma%s_%s" % (prefix, op.a.geom,
-                             "_xor_popc" if op.a.ptx_type == "b1" else "",
-                             suffix)
+  name = "{prefix}_{geom}_mma{b1op}_{suffix}".format(
+      prefix = prefix,
+      geom = op.a.geom,
+      b1op = op.b1op.replace(".","_"),
+      suffix = suffix)
   return name
 
-def get_required_sm(frag):
+def get_required_sm(frag, b1op=""):
   if frag.ptx_type in ["f64", "bf16", "tf32"]:
     return 80
   if frag.ptx_type in ["u4", "s4", "b1"]:
+    if b1op == "_and_popc":
+      return 80
     return 75
   if frag.ptx_type in ["s8", "u8"]:
     return 72
@@ -204,7 +219,9 @@ def get_required_sm(frag):
       return 70
   assert(False)
 
-def get_required_ptx(frag):
+def get_required_ptx(frag, b1op=""):
+  if frag.ptx_type == "b1" and b1op == ".and.popc":
+    return 71
   if frag.ptx_type in ["f64", "bf16", "tf32"]:
     return 70
   if frag.ptx_type in ["f16", "f32"]:
@@ -215,11 +232,13 @@ def get_required_ptx(frag):
     return 61
   return 63
 
-def get_src_dst_prefix(ptx_type):
-  if ptx_type == "f32":
+def get_src_dst_prefix(frag):
+  if frag.ptx_type == "f32":
     return "f"
-  if ptx_type == "f64":
+  if frag.ptx_type == "f64":
     return "d"
+  if frag.ptx_type == "tf32" and frag.frag in ["c", "d"]:
+    return "f"
   return ""
 
 def gen_wmma_ldst_tests(results):
@@ -235,9 +254,17 @@ def gen_wmma_ldst_tests(results):
     if not is_ldst_variant_supported(frag, layout):
       continue
 
-    src_dst_prefix = get_src_dst_prefix(frag.ptx_type)
+    src_dst_prefix = get_src_dst_prefix(frag)
+
     min_sm = get_required_sm(frag)
     min_ptx = get_required_ptx(frag)
+    # TF32 uses f32 for accumulator loads.
+    if frag.geom == "m16n16k8" and frag.frag =="c":
+      assert frag.ptx_type == "tf32"
+      itype = "f32"
+    else:
+      itype = frag.ptx_type
+
     params = {
         "check_suffix" : "_PTX%d_SM%d" % (min_ptx, min_sm),
         "builtin" : get_ldst_builtin_name(frag),
@@ -250,7 +277,7 @@ def gen_wmma_ldst_tests(results):
             "frag" : frag.frag,
             "geom"   : frag.geom,
             "ilayout" : layout,
-            "itype" : frag.ptx_type,
+            "itype" : itype,
             "op" : "store" if frag.frag == "d" else "load",
         })
     }
@@ -283,7 +310,7 @@ def gen_wmma_mma_tests(results):
   // expected-error-re at +1 {{'${builtin}' needs target feature (sm_${min_sm}{{.*}},(ptx${min_ptx}{{.*}}}}
   ${builtin}(${dst}, ${asrc}, ${asrc}, ${csrc}, ${ilayout}${maybe_satf});
 """.rstrip()
-  intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${intrinsic_signature}${satf}"
+  intrinsic_template = "llvm.nvvm.wmma.${geom}.mma${b1op}.${alayout}.${blayout}.${intrinsic_signature}${satf}"
 
   for op, alayout, blayout, satf in sorted(product( get_mma_ops(),
                                                     ["row","col"],
@@ -294,15 +321,15 @@ def gen_wmma_mma_tests(results):
     if not is_mma_variant_supported(op, alayout, blayout, satf):
       continue
 
-    asrc_prefix = get_src_dst_prefix(op.a.ptx_type)
-    csrc_prefix = get_src_dst_prefix(op.c.ptx_type)
-    ddst_prefix = get_src_dst_prefix(op.d.ptx_type)
-    min_sm = get_required_sm(op.a)
-    min_ptx = get_required_ptx(op.a)
+    asrc_prefix = get_src_dst_prefix(op.a)
+    csrc_prefix = get_src_dst_prefix(op.c)
+    ddst_prefix = get_src_dst_prefix(op.d)
     if op.a.ptx_type == "b1": # .b1 MMA has no satf argument.
        isatf_arg = ""
     else:
        isatf_arg = ", 1" if satf else ", 0"
+    min_sm = get_required_sm(op.a, op.b1op)
+    min_ptx = get_required_ptx(op.a, op.b1op)
     params = {
         "check_suffix" : "_PTX%d_SM%d" % (min_ptx, min_sm),
         "builtin" : get_mma_builtin_name(op),
@@ -319,6 +346,7 @@ def gen_wmma_mma_tests(results):
             "blayout" : blayout,
             "intrinsic_signature" : mma_signature(op),
             "satf"  : satf,
+            "b1op"  : op.b1op
         })
     }
     results[(min_ptx, min_sm)] += Template(mma_template).substitute(params)

diff  --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 3ce9dfb1bb807..cc43d23bec1ce 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -225,12 +225,13 @@ class MMA_SIGNATURE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
    string ret = !foldl("", id_frags, a, b, !strconcat(a, ".", b.ptx_elt_type));
 }
 
-class WMMA_NAME<string ALayout, string BLayout, int Satfinite, string Rnd,
+class WMMA_NAME<string ALayout, string BLayout, int Satfinite, string Rnd, string b1op,
                 WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
   string signature = MMA_SIGNATURE<A, B, C, D>.ret;
   string llvm = "llvm.nvvm.wmma."
                 # A.geom
                 # ".mma"
+                # b1op
                 # "." # ALayout
                 # "." # BLayout
                 # !if(!ne(Rnd, ""), !strconcat(".", Rnd), "")
@@ -241,11 +242,12 @@ class WMMA_NAME<string ALayout, string BLayout, int Satfinite, string Rnd,
                   !subst("llvm.", "int_", llvm));
 }
 
-class MMA_NAME<string ALayout, string BLayout, int Satfinite,
+class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op,
                WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
   string signature = MMA_SIGNATURE<A, B, C, D>.ret;
-  string llvm = "llvm.nvvm.mma."
-                # A.geom
+  string llvm = "llvm.nvvm.mma"
+                # b1op
+                # "." # A.geom
                 # "." # ALayout
                 # "." # BLayout
                 # !if(Satfinite, ".satfinite", "")
@@ -430,6 +432,13 @@ class NVVM_WMMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_
   );
 }
 
+class NVVM_MMA_B1OPS<list<WMMA_REGS> frags> {
+  list<string> ret = !cond(
+    !eq(frags[0].ptx_elt_type, "b1") : [".xor.popc", ".and.popc"],
+    true: [""]
+  );
+}
+
 // Returns true if this combination of layout/satf for MMA ops is supported;
 // false otherwise.
 // E.g.
@@ -4460,25 +4469,27 @@ foreach layout = ["row", "col"] in {
 }
 
 // WMMA.MMA
-class NVVM_WMMA_MMA<string ALayout, string BLayout, int Satfinite, string rnd,
+class NVVM_WMMA_MMA<string ALayout, string BLayout, int Satfinite, string rnd, string b1op,
                     WMMA_REGS A, WMMA_REGS B,
                     WMMA_REGS C, WMMA_REGS D>
   : Intrinsic<D.regs,
               !listconcat(A.regs, B.regs, C.regs),
               [IntrNoMem],
-              WMMA_NAME<ALayout, BLayout, Satfinite, rnd, A, B, C, D>.llvm>;
+              WMMA_NAME<ALayout, BLayout, Satfinite, rnd, b1op, A, B, C, D>.llvm>;
 
 foreach layout_a = ["row", "col"] in {
   foreach layout_b = ["row", "col"] in {
     foreach satf = [0, 1] in {
       foreach rnd = ["", "rn", "rz", "rm", "rp"] in {
         foreach op = NVVM_MMA_OPS.all_wmma_ops in {
-          if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, satf, rnd>.ret then {
-            def WMMA_NAME<layout_a, layout_b, satf, rnd,
-                              op[0], op[1], op[2], op[3]>.record
-              : NVVM_WMMA_MMA<layout_a, layout_b, satf, rnd,
-                              op[0], op[1], op[2], op[3]>;
-          }
+          foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
+            if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, satf, rnd>.ret then {
+              def WMMA_NAME<layout_a, layout_b, satf, rnd, b1op,
+                                op[0], op[1], op[2], op[3]>.record
+                : NVVM_WMMA_MMA<layout_a, layout_b, satf, rnd, b1op,
+                                op[0], op[1], op[2], op[3]>;
+            }
+          } // b1op
         } // op
       } // rnd
     } // satf
@@ -4486,21 +4497,23 @@ foreach layout_a = ["row", "col"] in {
 } // layout_a
 
 // MMA
-class NVVM_MMA<string ALayout, string BLayout, int Satfinite,
+class NVVM_MMA<string ALayout, string BLayout, int Satfinite, string b1op,
                WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
   : Intrinsic<D.regs,
               !listconcat(A.regs, B.regs, C.regs),
               [IntrNoMem],
-              MMA_NAME<ALayout, BLayout, Satfinite, A, B, C, D>.llvm>;
+              MMA_NAME<ALayout, BLayout, Satfinite, b1op, A, B, C, D>.llvm>;
 
 foreach layout_a = ["row", "col"] in {
   foreach layout_b = ["row", "col"] in {
     foreach satf = [0, 1] in {
       foreach op = NVVM_MMA_OPS.all_mma_ops in {
-        if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
-          def MMA_NAME<layout_a, layout_b, satf, op[0], op[1], op[2], op[3]>.record
-            : NVVM_MMA<layout_a, layout_b, satf, op[0], op[1], op[2], op[3]>;
-        }
+        foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
+          if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
+            def MMA_NAME<layout_a, layout_b, satf, b1op, op[0], op[1], op[2], op[3]>.record
+              : NVVM_MMA<layout_a, layout_b, satf, b1op, op[0], op[1], op[2], op[3]>;
+          }
+        } // b1op
       } // op
     } // satf
   } // layout_b

diff  --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index ab93bf16d4919..4834985b10190 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -146,6 +146,7 @@ def hasPTX63 : Predicate<"Subtarget->getPTXVersion() >= 63">;
 def hasPTX64 : Predicate<"Subtarget->getPTXVersion() >= 64">;
 def hasPTX65 : Predicate<"Subtarget->getPTXVersion() >= 65">;
 def hasPTX70 : Predicate<"Subtarget->getPTXVersion() >= 70">;
+def hasPTX71 : Predicate<"Subtarget->getPTXVersion() >= 71">;
 
 def hasSM30 : Predicate<"Subtarget->getSmVersion() >= 30">;
 def hasSM70 : Predicate<"Subtarget->getSmVersion() >= 70">;

diff  --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 798538410b104..de4bf2ef3055f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -7796,15 +7796,24 @@ defset list<WMMA_INSTR> MMA_LDSTs  = {
   } // layout
 } // defset
 
+// B1 instruction variants need extra constraints.
+class MMA_OP_PREDICATES<WMMA_REGINFO FragA, string b1op> {
+  string Op = b1op;
+  WMMA_REGINFO Frag = FragA;
+  list<Predicate> ret = !listconcat(
+    FragA.Predicates,
+    !if(!eq(b1op, ".and.popc"), [hasSM80,hasPTX71],[])
+  );
+}
 // WMMA.MMA
 class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                WMMA_REGINFO FragC, WMMA_REGINFO FragD,
-               string ALayout, string BLayout, int Satfinite, string rnd>
-  : WMMA_INSTR<WMMA_NAME<ALayout, BLayout, Satfinite, rnd, FragA, FragB, FragC, FragD>.record,
+               string ALayout, string BLayout, int Satfinite, string rnd, string b1op>
+  : WMMA_INSTR<WMMA_NAME<ALayout, BLayout, Satfinite, rnd, b1op, FragA, FragB, FragC, FragD>.record,
                          [FragA.Ins, FragB.Ins, FragC.Ins]>,
     // Requires does not seem to have effect on Instruction w/o Patterns.
     // We set it here anyways and propagate to the Pat<> we construct below.
-    Requires<FragA.Predicates> {
+    Requires<MMA_OP_PREDICATES<FragA, b1op>.ret> {
   let OutOperandList = FragD.Outs;
   let InOperandList  = !con(Args, (ins MmaCode:$ptx));
   string TypeList = !cond(
@@ -7816,7 +7825,7 @@ class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
        # "." # FragC.ptx_elt_type,
   );
   let AsmString = "wmma.mma"
-                  # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "")
+                  # b1op
                   # ".sync"
                   # "${ptx:aligned}"
                   # "." # ALayout
@@ -7837,13 +7846,15 @@ defset list<WMMA_INSTR> WMMAs  = {
       foreach satf = [0, 1] in {
         foreach rnd = ["", "rn", "rz", "rm", "rp"] in {
           foreach op = NVVM_MMA_OPS.all_wmma_ops in {
-            if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, satf, rnd>.ret then {
-              def : WMMA_MMA<WMMA_REGINFO<op[0], "wmma.mma">,
-                            WMMA_REGINFO<op[1], "wmma.mma">,
-                            WMMA_REGINFO<op[2], "wmma.mma">,
-                            WMMA_REGINFO<op[3], "wmma.mma">,
-                            layout_a, layout_b, satf, rnd>;
-            }
+            foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
+              if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, satf, rnd>.ret then {
+                def : WMMA_MMA<WMMA_REGINFO<op[0], "wmma.mma">,
+                              WMMA_REGINFO<op[1], "wmma.mma">,
+                              WMMA_REGINFO<op[2], "wmma.mma">,
+                              WMMA_REGINFO<op[3], "wmma.mma">,
+                              layout_a, layout_b, satf, rnd, b1op>;
+              }
+            } // b1op
           } // op
         } // rnd
       } // satf
@@ -7854,12 +7865,12 @@ defset list<WMMA_INSTR> WMMAs  = {
 // MMA
 class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                WMMA_REGINFO FragC, WMMA_REGINFO FragD,
-               string ALayout, string BLayout, int Satfinite>
-  : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, FragA, FragB, FragC, FragD>.record,
+               string ALayout, string BLayout, int Satfinite, string b1op>
+  : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, FragA, FragB, FragC, FragD>.record,
                         [FragA.Ins, FragB.Ins, FragC.Ins]>,
     // Requires does not seem to have effect on Instruction w/o Patterns.
     // We set it here anyways and propagate to the Pat<> we construct below.
-    Requires<FragA.Predicates> {
+  Requires<MMA_OP_PREDICATES<FragA, b1op>.ret> {
   let OutOperandList = FragD.Outs;
   let InOperandList  = !con(Args, (ins MmaCode:$ptx));
   string TypeList = "." # FragD.ptx_elt_type
@@ -7872,7 +7883,7 @@ class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                   # "." # BLayout
                   # !if(Satfinite, ".satfinite", "")
                   # TypeList
-                  # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "") # "\n\t\t"
+                  # b1op # "\n\t\t"
                   # FragD.regstring # ",\n\t\t"
                   # FragA.regstring # ",\n\t\t"
                   # FragB.regstring # ",\n\t\t"
@@ -7884,13 +7895,15 @@ defset list<WMMA_INSTR> MMAs  = {
     foreach layout_b = ["row", "col"] in {
       foreach satf = [0, 1] in {
         foreach op = NVVM_MMA_OPS.all_mma_ops in {
-          if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
-            def : MMA<WMMA_REGINFO<op[0], "mma">,
-                      WMMA_REGINFO<op[1], "mma">,
-                      WMMA_REGINFO<op[2], "mma">,
-                      WMMA_REGINFO<op[3], "mma">,
-                      layout_a, layout_b, satf>;
-          }
+          foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
+            if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
+              def : MMA<WMMA_REGINFO<op[0], "mma">,
+                        WMMA_REGINFO<op[1], "mma">,
+                        WMMA_REGINFO<op[2], "mma">,
+                        WMMA_REGINFO<op[3], "mma">,
+                        layout_a, layout_b, satf, b1op>;
+            }
+          } // b1op
         } // op
       } // satf
     } // layout_b

diff  --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 2daffd0e2cf6c..785e48ce75a24 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -55,14 +55,14 @@
 # RUN: llc < %t-ptx65-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx65 \
 # RUN:           | FileCheck %t-ptx65-sm_75.ll
 
-# Check all variants of instructions supported by PTX70 on SM80+
-# RUN: %python %s --ptx=70 --gpu-arch=80 > %t-ptx70-sm_80.ll
-# RUN: FileCheck %t-ptx70-sm_80.ll < %t-ptx70-sm_80.ll \
-# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,ALTFLOAT,DOUBLE,PTX65MMA,PTX70MMA
-# RUN: FileCheck %t-ptx70-sm_80.ll < %t-ptx70-sm_80.ll \
+# Check all variants of instructions supported by PTX71 on SM80+
+# RUN: %python %s --ptx=71 --gpu-arch=80 > %t-ptx71-sm_80.ll
+# RUN: FileCheck %t-ptx71-sm_80.ll < %t-ptx71-sm_80.ll \
+# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,ALTFLOAT,DOUBLE,PTX65MMA,PTX71MMA
+# RUN: FileCheck %t-ptx71-sm_80.ll < %t-ptx71-sm_80.ll \
 # RUN:           --check-prefixes=INTRINSICS
-# RUN: llc < %t-ptx70-sm_80.ll -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 \
-# RUN:           | FileCheck %t-ptx70-sm_80.ll
+# RUN: llc < %t-ptx71-sm_80.ll -march=nvptx64 -mcpu=sm_80 -mattr=+ptx71 \
+# RUN:           | FileCheck %t-ptx71-sm_80.ll
 
 from __future__ import print_function
 
@@ -649,9 +649,16 @@ def common_mma_test_gen(params, op, intrinsic_template, instruction_template):
   print(Template(mma_template).substitute(test_params))
   return (test_params["intrinsic"], test_params["instruction"])
 
+def get_b1_ops(ptx_type):
+  if ptx_type != "b1":
+    return [""]
+  if ptx_version >= 71:
+    return [".xor.popc", ".and.popc"]
+  return [".xor.popc"]
+
 def gen_wmma_mma_tests():
-  wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}${rnd}.${intrinsic_signature}${satf}"
-  wmma_instruction_template = "wmma.mma${mma_variant}.sync${aligned}.${alayout}.${blayout}.${geom}${rnd}.${ptx_signature}${satf}"
+  wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma${b1op}.${alayout}.${blayout}${rnd}.${intrinsic_signature}${satf}"
+  wmma_instruction_template = "wmma.mma${b1op}.sync${aligned}.${alayout}.${blayout}.${geom}${rnd}.${ptx_signature}${satf}"
 
   generated_items=[]
 
@@ -665,29 +672,30 @@ def gen_wmma_mma_tests():
     if not is_wmma_variant_supported(op, alayout, blayout, rnd, satf):
       continue
 
-    params = {
-        "aligned" : ".aligned" if ptx_version >= 63 else "",
-        "alayout" : alayout,
-        "blayout" : blayout,
-        "intrinsic_signature" : wmma_signature(op),
-        "ptx_signature" : wmma_ptx_signature(op),
-        "satf"  : satf,
-        "rnd"   : rnd,
-        "geom"  : op.a.geom,
-        "mma_variant" : ".xor.popc" if op.a.mma_type.ptx_type == "b1" else "",
-    }
-
-    intrinsic_template = wmma_intrinsic_template
-    instruction_template = wmma_instruction_template
-
-    generated_items.append(common_mma_test_gen(params, op,
-      intrinsic_template, instruction_template))
+    for b1op in get_b1_ops(op.a.mma_type.ptx_type):
+      params = {
+          "aligned" : ".aligned" if ptx_version >= 63 else "",
+          "alayout" : alayout,
+          "blayout" : blayout,
+          "intrinsic_signature" : wmma_signature(op),
+          "ptx_signature" : wmma_ptx_signature(op),
+          "satf"  : satf,
+          "rnd"   : rnd,
+          "geom"  : op.a.geom,
+          "b1op"  : b1op
+      }
+
+      intrinsic_template = wmma_intrinsic_template
+      instruction_template = wmma_instruction_template
+
+      generated_items.append(common_mma_test_gen(params, op,
+                                                 intrinsic_template, instruction_template))
 
   return generated_items
 
 def gen_mma_tests():
-  mma_intrinsic_template = "llvm.nvvm.mma.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}"
-  mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${mma_variant}"
+  mma_intrinsic_template = "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}"
+  mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${b1op}"
 
   generated_items=[]
 
@@ -700,22 +708,23 @@ def gen_mma_tests():
     if not is_mma_variant_supported(op, alayout, blayout, satf):
       continue
 
-    params = {
-        "aligned" : ".aligned" if ptx_version >= 63 else "",
-        "alayout" : alayout,
-        "blayout" : blayout,
-        "intrinsic_signature" : mma_signature(op),
-        "ptx_signature" : mma_ptx_signature(op),
-        "satf"  : satf,
-        "geom"  : op.a.geom,
-        "mma_variant" : ".xor.popc" if op.a.mma_type.ptx_type == "b1" else "",
-    }
+    for b1op in get_b1_ops(op.a.mma_type.ptx_type):
+      params = {
+          "aligned" : ".aligned" if ptx_version >= 63 else "",
+          "alayout" : alayout,
+          "blayout" : blayout,
+          "intrinsic_signature" : mma_signature(op),
+          "ptx_signature" : mma_ptx_signature(op),
+          "satf"  : satf,
+          "geom"  : op.a.geom,
+          "b1op"  : b1op
+      }
 
-    intrinsic_template = mma_intrinsic_template
-    instruction_template = mma_instruction_template
+      intrinsic_template = mma_intrinsic_template
+      instruction_template = mma_instruction_template
 
-    generated_items.append(common_mma_test_gen(params, op,
-      intrinsic_template, instruction_template))
+      generated_items.append(common_mma_test_gen(params, op,
+        intrinsic_template, instruction_template))
 
   return generated_items
 
@@ -810,32 +819,35 @@ def gen_check_unsupported_ops(items):
 ; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.u4
 ; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.s4
 
-; PTX70MMA-DAG: mma.m8n8k4.row.col.f64
-; PTX70MMA-DAG: mma.m16n8k4.row.col.tf32
-; PTX70MMA-DAG: mma.m16n8k8.row.col.tf32
-; PTX70MMA-DAG: mma.m16n8k16.row.col.bf16
-; PTX70MMA-DAG: mma.m16n8k8.row.col.bf16
-; PTX70MMA-DAG: mma.m16n8k16.row.col.f16.f16
-; PTX70MMA-DAG: mma.m16n8k16.row.col.f32.f32
-; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.u8
-; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.s8
-; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.u8
-; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.s8
-; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.u8
-; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.s8
-; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.u8
-; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.s8
-; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.u4
-; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.s4
-; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.u4
-; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.s4
-; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.u4
-; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.s4
-; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.u4
-; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.s4
-; PTX70MMA-DAG: mma.m8n8k128.row.col.b1
-; PTX70MMA-DAG: mma.m16n8k128.row.col.b1
-; PTX70MMA-DAG: mma.m16n8k256.row.col.b1
+; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
+; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
+; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
+; PTX71MMA-DAG: mma.m16n8k16.row.col.bf16
+; PTX71MMA-DAG: mma.m16n8k8.row.col.bf16
+; PTX71MMA-DAG: mma.m16n8k16.row.col.f16.f16
+; PTX71MMA-DAG: mma.m16n8k16.row.col.f32.f32
+; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.u8
+; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.s8
+; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.u8
+; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.s8
+; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.u8
+; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.s8
+; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.u8
+; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.s8
+; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.u4
+; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.s4
+; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.u4
+; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.s4
+; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.u4
+; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.s4
+; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.u4
+; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.s4
+; PTX71MMA-DAG: mma.and.popc.m8n8k128.row.col.b1
+; PTX71MMA-DAG: mma.xor.popc.m8n8k128.row.col.b1
+; PTX71MMA-DAG: mma.and.popc.m16n8k128.row.col.b1
+; PTX71MMA-DAG: mma.xor.popc.m16n8k128.row.col.b1
+; PTX71MMA-DAG: mma.and.popc.m16n8k256.row.col.b1
+; PTX71MMA-DAG: mma.xor.popc.m16n8k256.row.col.b1
 ;
 
 """)


        


More information about the llvm-commits mailing list