[llvm] d774b4a - [NVPTX, CUDA] Add .and.popc variant of the b1 MMA instruction.
Artem Belevich via llvm-commits
llvm-commits at lists.llvm.org
Thu Jul 15 12:02:53 PDT 2021
Author: Artem Belevich
Date: 2021-07-15T12:02:09-07:00
New Revision: d774b4aa5eac785ffe40009091667521e183df40
URL: https://github.com/llvm/llvm-project/commit/d774b4aa5eac785ffe40009091667521e183df40
DIFF: https://github.com/llvm/llvm-project/commit/d774b4aa5eac785ffe40009091667521e183df40.diff
LOG: [NVPTX, CUDA] Add .and.popc variant of the b1 MMA instruction.
This should allow clang to compile mma.h from CUDA-11.3.
Differential Revision: https://reviews.llvm.org/D105384
Added:
Modified:
clang/include/clang/Basic/BuiltinsNVPTX.def
clang/lib/CodeGen/CGBuiltin.cpp
clang/test/CodeGen/builtins-nvptx-mma.cu
clang/test/CodeGen/builtins-nvptx-mma.py
llvm/include/llvm/IR/IntrinsicsNVVM.td
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
llvm/test/CodeGen/NVPTX/wmma.py
Removed:
################################################################################
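A quick orientation on semantics before the diff: per the PTX ISA, the single-bit MMA ops multiply bit tiles with a boolean operation and accumulate population counts, so the new variant computes d = popc(a & b) + c where the existing .xor.popc variant uses XOR. Below is a minimal CUDA sketch — a scalar model of one 32-bit chunk plus a call to the new builtin in the same form the generated test further down uses; the function names are illustrative, and the trailing constant 1 selects the row.col layout, the only combination wired up in the intrinsic tables:

  // Scalar model of one k=32 chunk of the b1 dot product; the real
  // instruction operates on whole m8n8k128 tiles, not scalars.
  __device__ int b1_and_popc_step(unsigned a, unsigned b, int c) {
    return __popc(a & b) + c;  // the .xor.popc variant uses (a ^ b) instead
  }

  __device__ void demo_and_popc(int *dst, const int *src) {
    __bmma_m8n8k128_mma_and_popc_b1(dst, src, src, src, 1 /*row.col*/);
  }
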
diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.def b/clang/include/clang/Basic/BuiltinsNVPTX.def
index e815138a15c15..3c96900136a40 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.def
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.def
@@ -724,6 +724,7 @@ TARGET_BUILTIN(__hmma_m8n32k16_mma_f16f32, "vi*iC*iC*fC*IiIi", "", AND(SM_70,PTX
TARGET_BUILTIN(__bmma_m8n8k128_ld_a_b1, "vi*iC*UiIi", "", AND(SM_75,PTX63))
TARGET_BUILTIN(__bmma_m8n8k128_ld_b_b1, "vi*iC*UiIi", "", AND(SM_75,PTX63))
TARGET_BUILTIN(__bmma_m8n8k128_ld_c, "vi*iC*UiIi", "", AND(SM_75,PTX63))
+TARGET_BUILTIN(__bmma_m8n8k128_mma_and_popc_b1, "vi*iC*iC*iC*Ii", "", AND(SM_75,PTX71))
TARGET_BUILTIN(__bmma_m8n8k128_mma_xor_popc_b1, "vi*iC*iC*iC*Ii", "", AND(SM_75,PTX63))
TARGET_BUILTIN(__bmma_m8n8k128_st_c_i32, "vi*iC*UiIi", "", AND(SM_75,PTX63))
TARGET_BUILTIN(__imma_m16n16k16_ld_a_s8, "vi*iC*UiIi", "", AND(SM_72,PTX63))
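As a decoding aid for the signature string above, based on the Builtins.def letter encoding: "vi*iC*iC*iC*Ii" reads as a void return, an int* destination, three const int* fragment pointers, and an int that must be an integer constant expression (the I prefix). The prototype is therefore roughly:

  void __bmma_m8n8k128_mma_and_popc_b1(int *dst, const int *a, const int *b,
                                       const int *c, int layout);  // layout: constant 0..3
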
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 0635be425e0aa..4091b6cc62ce9 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -16630,9 +16630,18 @@ static NVPTXMmaInfo getNVPTXMmaInfo(unsigned BuiltinID) {
0, \
0
// b1 MMA does not support .satfinite.
-#define MMA_VARIANTS_B1(geom, type) \
+#define MMA_VARIANTS_B1_XOR(geom, type) \
0, \
- Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type, \
+ Intrinsic::nvvm_wmma_##geom##_mma_xor_popc_row_col_##type, \
+ 0, \
+ 0, \
+ 0, \
+ 0, \
+ 0, \
+ 0
+#define MMA_VARIANTS_B1_AND(geom, type) \
+ 0, \
+ Intrinsic::nvvm_wmma_##geom##_mma_and_popc_row_col_##type, \
0, \
0, \
0, \
@@ -16689,7 +16698,9 @@ static NVPTXMmaInfo getNVPTXMmaInfo(unsigned BuiltinID) {
case NVPTX::BI__imma_m8n8k32_mma_u4:
return {1, 1, 2, 2, {{MMA_VARIANTS_I4(m8n8k32, u4)}}};
case NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1:
- return {1, 1, 2, 2, {{MMA_VARIANTS_B1(m8n8k128, b1)}}};
+ return {1, 1, 2, 2, {{MMA_VARIANTS_B1_XOR(m8n8k128, b1)}}};
+ case NVPTX::BI__bmma_m8n8k128_mma_and_popc_b1:
+ return {1, 1, 2, 2, {{MMA_VARIANTS_B1_AND(m8n8k128, b1)}}};
// Double MMA
case NVPTX::BI__dmma_m8n8k4_mma_f64:
@@ -16710,7 +16721,8 @@ static NVPTXMmaInfo getNVPTXMmaInfo(unsigned BuiltinID) {
#undef MMA_VARIANTS
#undef MMA_SATF_VARIANTS
#undef MMA_VARIANTS_I4
-#undef MMA_VARIANTS_B1
+#undef MMA_VARIANTS_B1_AND
+#undef MMA_VARIANTS_B1_XOR
}
} // namespace
@@ -17119,6 +17131,7 @@ CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, const CallExpr *E) {
case NVPTX::BI__imma_m8n8k32_mma_s4:
case NVPTX::BI__imma_m8n8k32_mma_u4:
case NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1:
+ case NVPTX::BI__bmma_m8n8k128_mma_and_popc_b1:
case NVPTX::BI__dmma_m8n8k4_mma_f64:
case NVPTX::BI__mma_bf16_m16n16k16_mma_f32:
case NVPTX::BI__mma_bf16_m8n32k16_mma_f32:
@@ -17136,7 +17149,8 @@ CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, const CallExpr *E) {
if (Layout < 0 || Layout > 3)
return nullptr;
llvm::APSInt SatfArg;
- if (BuiltinID == NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1)
+ if (BuiltinID == NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1 ||
+ BuiltinID == NVPTX::BI__bmma_m8n8k128_mma_and_popc_b1)
SatfArg = 0; // .b1 does not have satf argument.
else if (Optional<llvm::APSInt> OptSatfArg =
E->getArg(5)->getIntegerConstantExpr(getContext()))
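One reading of the zero-padded macros above: each MMA_VARIANTS_* expansion fills an eight-slot table — the four plain layout combinations followed by their .satfinite forms — and a slot of 0 marks that combination as unsupported, so the b1 variants leave only the row.col, non-satfinite slot live. A hypothetical standalone restatement (the index arithmetic is an assumption for illustration, not copied from CGBuiltin.cpp):

  #include <array>

  // Hypothetical model: index = layout + 4 * satf; a zero entry means the
  // layout/satf combination is unsupported and codegen bails out.
  constexpr unsigned AndPopcRowColB1 = 1;  // placeholder for the intrinsic ID
  constexpr std::array<unsigned, 8> B1AndVariants = {
      0, AndPopcRowColB1, 0, 0,   // row.row, row.col, col.row, col.col
      0, 0, 0, 0};                // the same four layouts with .satfinite

  unsigned selectVariant(unsigned layout, bool satf) {
    unsigned idx = layout + 4 * satf;
    return idx < B1AndVariants.size() ? B1AndVariants[idx] : 0;
  }
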
diff --git a/clang/test/CodeGen/builtins-nvptx-mma.cu b/clang/test/CodeGen/builtins-nvptx-mma.cu
index 7e9bac86792d2..aaa44bcaa7e22 100644
--- a/clang/test/CodeGen/builtins-nvptx-mma.cu
+++ b/clang/test/CodeGen/builtins-nvptx-mma.cu
@@ -3,20 +3,21 @@
// *** DO NOT EDIT ***
//
// This test has been automatically generated by
-// builtins-nvtx-mma.py --ptx=70 --gpu-arch=80
+// builtins-nvtx-mma.py --ptx=71 --gpu-arch=80
//
-// Make sure we can handle all builtins available on sm_80 with PTX70
+// Make sure we can handle all builtins available on sm_80 with PTX71
// RUN: %clang_cc1 -triple nvptx64-unknown-unknown -target-cpu sm_80 \
-// RUN: -fcuda-is-device -target-feature +ptx70 \
-// RUN: -DPTX=70 -DSM=80 \
+// RUN: -fcuda-is-device -target-feature +ptx71 \
+// RUN: -DPTX=71 -DSM=80 \
// RUN: -S -emit-llvm -o - -x cuda %s \
-// RUN: | FileCheck -check-prefixes=CHECK_PTX70_SM80,CHECK_PTX60_SM70,CHECK_PTX63_SM72,CHECK_PTX61_SM70,CHECK_PTX63_SM75 %s
+// RUN: | FileCheck -check-prefixes=CHECK_PTX70_SM80,CHECK_PTX60_SM70,CHECK_PTX63_SM72,CHECK_PTX61_SM70,CHECK_PTX63_SM75,CHECK_PTX71_SM75 %s
// Verify that all builtins have correct constraints.
// RUN: %clang_cc1 -triple nvptx-unknown-unknown \
// RUN: -target-cpu sm_60 -target-feature +ptx42 \
-// RUN: -DPTX=70 -DSM=80 -fcuda-is-device -S -o /dev/null -x cuda \
+// RUN: -DPTX=71 -DSM=80 -fcuda-is-device -S -o /dev/null -x cuda \
// RUN: -verify %s
+
#if !defined(CUDA_VERSION)
#define __device__ __attribute__((device))
#define __global__ __attribute__((global))
@@ -31,6 +32,7 @@ __device__ void test_wmma_buitins(int *src, int *dst,
float *fsrc, float *fdst,
double *dsrc, double *ddst, int ldm) {
+
#if (PTX >= 60) && (SM >= 70)
// CHECK_PTX60_SM70: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.f16
@@ -735,7 +737,7 @@ __device__ void test_wmma_buitins(int *src, int *dst,
// CHECK_PTX63_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k32.store.d.row.stride.s32
// expected-error-re@+1 {{'__imma_m8n8k32_st_c_i32' needs target feature (sm_75{{.*}},(ptx63{{.*}}}}
__imma_m8n8k32_st_c_i32(dst, src, ldm, 0);
- // CHECK_PTX63_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k128.mma.row.col.b1
+ // CHECK_PTX63_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k128.mma.xor.popc.row.col.b1
// expected-error-re@+1 {{'__bmma_m8n8k128_mma_xor_popc_b1' needs target feature (sm_75{{.*}},(ptx63{{.*}}}}
__bmma_m8n8k128_mma_xor_popc_b1(dst, src, src, src, 1);
// CHECK_PTX63_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k32.mma.row.col.s4
@@ -750,7 +752,7 @@ __device__ void test_wmma_buitins(int *src, int *dst,
// CHECK_PTX63_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k32.mma.row.col.u4.satfinite
// expected-error-re@+1 {{'__imma_m8n8k32_mma_u4' needs target feature (sm_75{{.*}},(ptx63{{.*}}}}
__imma_m8n8k32_mma_u4(dst, src, src, src, 1, 1);
-#endif // (PTX >= 63) && (SM >= 75)
+#endif // (PTX >= 63) && (SM >= 75)
#if (PTX >= 70) && (SM >= 80)
@@ -898,5 +900,12 @@ __device__ void test_wmma_buitins(int *src, int *dst,
// CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64
// expected-error-re@+1 {{'__dmma_m8n8k4_mma_f64' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
__dmma_m8n8k4_mma_f64(ddst, dsrc, dsrc, dsrc, 0, 0);
-#endif // (PTX >= 70) && (SM >= 80)
+#endif // (PTX >= 70) && (SM >= 80)
+
+#if (PTX >= 71) && (SM >= 75)
+
+ // CHECK_PTX71_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k128.mma.and.popc.row.col.b1
+ // expected-error-re@+1 {{'__bmma_m8n8k128_mma_and_popc_b1' needs target feature (sm_75{{.*}},(ptx71{{.*}}}}
+ __bmma_m8n8k128_mma_and_popc_b1(dst, src, src, src, 1);
+#endif // (PTX >= 71) && (SM >= 75)
}
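The generated test above gates the new call on its own -DPTX/-DSM macros; hand-written CUDA has no direct PTX-version macro, so a guard would typically key on the compute capability alone. A hedged sketch, reusing the pointer convention from the test:

  __device__ void guarded_and_popc(int *dst, const int *src) {
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    // PTX 7.1 availability is a property of the toolchain, not of the
    // device, so it cannot be tested from device code.
    __bmma_m8n8k128_mma_and_popc_b1(dst, src, src, src, 1);
  #endif
  }
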
diff --git a/clang/test/CodeGen/builtins-nvptx-mma.py b/clang/test/CodeGen/builtins-nvptx-mma.py
index 2ffc21b12fb06..dc40f04c11ce6 100644
--- a/clang/test/CodeGen/builtins-nvptx-mma.py
+++ b/clang/test/CodeGen/builtins-nvptx-mma.py
@@ -22,24 +22,29 @@ def __repr__(self):
return "%s:%s:%s" % (self.geom, self.frag, self.ptx_type)
class MMAOp:
- def __init__(self, a, b, c, d):
+ def __init__(self, a, b, c, d, b1op=""):
self.a = a
self.b = b
self.c = c
self.d = d
+ self.b1op = b1op
def __repr__(self):
return ("{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d ))
-def make_mma_ops(geoms, types_a, types_b, types_c, types_d):
+def make_mma_ops(geoms, types_a, types_b, types_c, types_d, b1ops=None):
ops = []
+ if b1ops is None:
+ b1ops = [""]
for geom, type_a, type_c in product( geoms, types_a, types_c):
for type_b, type_d in product(types_b if types_b else [type_a],
types_d if types_d else [type_c]):
- ops.append(MMAOp(MMAFrag(geom, "a", type_a),
- MMAFrag(geom, "b", type_b),
- MMAFrag(geom, "c", type_c),
- MMAFrag(geom, "d", type_d)))
+ ops += [
+ MMAOp(MMAFrag(geom, "a", type_a),
+ MMAFrag(geom, "b", type_b),
+ MMAFrag(geom, "c", type_c),
+ MMAFrag(geom, "d", type_d), b1op)
+ for b1op in b1ops]
return ops
def make_ldst_ops(geoms, frags, types):
@@ -60,9 +65,12 @@ def get_mma_ops():
make_mma_ops(["m8n8k32"],
["s4", "u4"], [], ["s32"], []) +
make_mma_ops(["m8n8k128"],
- ["b1"], [], ["s32"], []))
+ ["b1"], [], ["s32"], [],
+ [".xor.popc", ".and.popc"]))
def get_ldst_ops():
+ # NOTE: fragments are named from the point of view of PTX.
+ # Fragment `d` is used only by store ops; the others serve both loads and stores.
return (make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
["a", "b"], ["f16", "u8", "s8", "bf16"]) +
make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
@@ -71,8 +79,11 @@ def get_ldst_ops():
make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"]) +
make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]) +
make_ldst_ops(["m8n8k4"], ["a", "b", "c", "d"], ["f64"]) +
- make_ldst_ops(["m16n16k8"], ["a", "b"], ["tf32"]) +
- make_ldst_ops(["m16n16k8"], ["c", "d"], ["f32"]))
+ # TF32 m16n16k8 is odd.
+ # For fragment 'C' it uses __mma_*tf32*_m16n16k8_ld_c
+ # but 'D' calls __mma_m16n16k8_st_c_*f32*.
+ make_ldst_ops(["m16n16k8"], ["a", "b", "c"], ["tf32"]) +
+ make_ldst_ops(["m16n16k8"], ["d"], ["f32"]))
def is_geom_supported(geom):
# geometries for FP and ints.
@@ -180,15 +191,19 @@ def get_mma_builtin_name(op):
else:
suffix = op.a.ptx_type
- name = "%s_%s_mma%s_%s" % (prefix, op.a.geom,
- "_xor_popc" if op.a.ptx_type == "b1" else "",
- suffix)
+ name = "{prefix}_{geom}_mma{b1op}_{suffix}".format(
+ prefix = prefix,
+ geom = op.a.geom,
+ b1op = op.b1op.replace(".","_"),
+ suffix = suffix)
return name
-def get_required_sm(frag):
+def get_required_sm(frag, b1op=""):
if frag.ptx_type in ["f64", "bf16", "tf32"]:
return 80
if frag.ptx_type in ["u4", "s4", "b1"]:
+ if b1op == "_and_popc":
+ return 80
return 75
if frag.ptx_type in ["s8", "u8"]:
return 72
@@ -204,7 +219,9 @@ def get_required_sm(frag):
return 70
assert(False)
-def get_required_ptx(frag):
+def get_required_ptx(frag, b1op=""):
+ if frag.ptx_type == "b1" and b1op == ".and.popc":
+ return 71
if frag.ptx_type in ["f64", "bf16", "tf32"]:
return 70
if frag.ptx_type in ["f16", "f32"]:
@@ -215,11 +232,13 @@ def get_required_ptx(frag):
return 61
return 63
-def get_src_dst_prefix(ptx_type):
- if ptx_type == "f32":
+def get_src_dst_prefix(frag):
+ if frag.ptx_type == "f32":
return "f"
- if ptx_type == "f64":
+ if frag.ptx_type == "f64":
return "d"
+ if frag.ptx_type == "tf32" and frag.frag in ["c", "d"]:
+ return "f"
return ""
def gen_wmma_ldst_tests(results):
@@ -235,9 +254,17 @@ def gen_wmma_ldst_tests(results):
if not is_ldst_variant_supported(frag, layout):
continue
- src_dst_prefix = get_src_dst_prefix(frag.ptx_type)
+ src_dst_prefix = get_src_dst_prefix(frag)
+
min_sm = get_required_sm(frag)
min_ptx = get_required_ptx(frag)
+ # TF32 uses f32 for accumulator loads.
+ if frag.geom == "m16n16k8" and frag.frag == "c":
+ assert frag.ptx_type == "tf32"
+ itype = "f32"
+ else:
+ itype = frag.ptx_type
+
params = {
"check_suffix" : "_PTX%d_SM%d" % (min_ptx, min_sm),
"builtin" : get_ldst_builtin_name(frag),
@@ -250,7 +277,7 @@ def gen_wmma_ldst_tests(results):
"frag" : frag.frag,
"geom" : frag.geom,
"ilayout" : layout,
- "itype" : frag.ptx_type,
+ "itype" : itype,
"op" : "store" if frag.frag == "d" else "load",
})
}
@@ -283,7 +310,7 @@ def gen_wmma_mma_tests(results):
// expected-error-re@+1 {{'${builtin}' needs target feature (sm_${min_sm}{{.*}},(ptx${min_ptx}{{.*}}}}
${builtin}(${dst}, ${asrc}, ${asrc}, ${csrc}, ${ilayout}${maybe_satf});
""".rstrip()
- intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${intrinsic_signature}${satf}"
+ intrinsic_template = "llvm.nvvm.wmma.${geom}.mma${b1op}.${alayout}.${blayout}.${intrinsic_signature}${satf}"
for op, alayout, blayout, satf in sorted(product( get_mma_ops(),
["row","col"],
@@ -294,15 +321,15 @@ def gen_wmma_mma_tests(results):
if not is_mma_variant_supported(op, alayout, blayout, satf):
continue
- asrc_prefix = get_src_dst_prefix(op.a.ptx_type)
- csrc_prefix = get_src_dst_prefix(op.c.ptx_type)
- ddst_prefix = get_src_dst_prefix(op.d.ptx_type)
- min_sm = get_required_sm(op.a)
- min_ptx = get_required_ptx(op.a)
+ asrc_prefix = get_src_dst_prefix(op.a)
+ csrc_prefix = get_src_dst_prefix(op.c)
+ ddst_prefix = get_src_dst_prefix(op.d)
if op.a.ptx_type == "b1": # .b1 MMA has no satf argument.
isatf_arg = ""
else:
isatf_arg = ", 1" if satf else ", 0"
+ min_sm = get_required_sm(op.a, op.b1op)
+ min_ptx = get_required_ptx(op.a, op.b1op)
params = {
"check_suffix" : "_PTX%d_SM%d" % (min_ptx, min_sm),
"builtin" : get_mma_builtin_name(op),
@@ -319,6 +346,7 @@ def gen_wmma_mma_tests(results):
"blayout" : blayout,
"intrinsic_signature" : mma_signature(op),
"satf" : satf,
+ "b1op" : op.b1op
})
}
results[(min_ptx, min_sm)] += Template(mma_template).substitute(params)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 3ce9dfb1bb807..cc43d23bec1ce 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -225,12 +225,13 @@ class MMA_SIGNATURE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
string ret = !foldl("", id_frags, a, b, !strconcat(a, ".", b.ptx_elt_type));
}
-class WMMA_NAME<string ALayout, string BLayout, int Satfinite, string Rnd,
+class WMMA_NAME<string ALayout, string BLayout, int Satfinite, string Rnd, string b1op,
WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
string signature = MMA_SIGNATURE<A, B, C, D>.ret;
string llvm = "llvm.nvvm.wmma."
# A.geom
# ".mma"
+ # b1op
# "." # ALayout
# "." # BLayout
# !if(!ne(Rnd, ""), !strconcat(".", Rnd), "")
@@ -241,11 +242,12 @@ class WMMA_NAME<string ALayout, string BLayout, int Satfinite, string Rnd,
!subst("llvm.", "int_", llvm));
}
-class MMA_NAME<string ALayout, string BLayout, int Satfinite,
+class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op,
WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
string signature = MMA_SIGNATURE<A, B, C, D>.ret;
- string llvm = "llvm.nvvm.mma."
- # A.geom
+ string llvm = "llvm.nvvm.mma"
+ # b1op
+ # "." # A.geom
# "." # ALayout
# "." # BLayout
# !if(Satfinite, ".satfinite", "")
@@ -430,6 +432,13 @@ class NVVM_WMMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_
);
}
+class NVVM_MMA_B1OPS<list<WMMA_REGS> frags> {
+ list<string> ret = !cond(
+ !eq(frags[0].ptx_elt_type, "b1") : [".xor.popc", ".and.popc"],
+ true: [""]
+ );
+}
+
// Returns true if this combination of layout/satf for MMA ops is supported;
// false otherwise.
// E.g.
@@ -4460,25 +4469,27 @@ foreach layout = ["row", "col"] in {
}
// WMMA.MMA
-class NVVM_WMMA_MMA<string ALayout, string BLayout, int Satfinite, string rnd,
+class NVVM_WMMA_MMA<string ALayout, string BLayout, int Satfinite, string rnd, string b1op,
WMMA_REGS A, WMMA_REGS B,
WMMA_REGS C, WMMA_REGS D>
: Intrinsic<D.regs,
!listconcat(A.regs, B.regs, C.regs),
[IntrNoMem],
- WMMA_NAME<ALayout, BLayout, Satfinite, rnd, A, B, C, D>.llvm>;
+ WMMA_NAME<ALayout, BLayout, Satfinite, rnd, b1op, A, B, C, D>.llvm>;
foreach layout_a = ["row", "col"] in {
foreach layout_b = ["row", "col"] in {
foreach satf = [0, 1] in {
foreach rnd = ["", "rn", "rz", "rm", "rp"] in {
foreach op = NVVM_MMA_OPS.all_wmma_ops in {
- if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, satf, rnd>.ret then {
- def WMMA_NAME<layout_a, layout_b, satf, rnd,
- op[0], op[1], op[2], op[3]>.record
- : NVVM_WMMA_MMA<layout_a, layout_b, satf, rnd,
- op[0], op[1], op[2], op[3]>;
- }
+ foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
+ if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, satf, rnd>.ret then {
+ def WMMA_NAME<layout_a, layout_b, satf, rnd, b1op,
+ op[0], op[1], op[2], op[3]>.record
+ : NVVM_WMMA_MMA<layout_a, layout_b, satf, rnd, b1op,
+ op[0], op[1], op[2], op[3]>;
+ }
+ } // b1op
} // op
} // rnd
} // satf
@@ -4486,21 +4497,23 @@ foreach layout_a = ["row", "col"] in {
} // layout_a
// MMA
-class NVVM_MMA<string ALayout, string BLayout, int Satfinite,
+class NVVM_MMA<string ALayout, string BLayout, int Satfinite, string b1op,
WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
: Intrinsic<D.regs,
!listconcat(A.regs, B.regs, C.regs),
[IntrNoMem],
- MMA_NAME<ALayout, BLayout, Satfinite, A, B, C, D>.llvm>;
+ MMA_NAME<ALayout, BLayout, Satfinite, b1op, A, B, C, D>.llvm>;
foreach layout_a = ["row", "col"] in {
foreach layout_b = ["row", "col"] in {
foreach satf = [0, 1] in {
foreach op = NVVM_MMA_OPS.all_mma_ops in {
- if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
- def MMA_NAME<layout_a, layout_b, satf, op[0], op[1], op[2], op[3]>.record
- : NVVM_MMA<layout_a, layout_b, satf, op[0], op[1], op[2], op[3]>;
- }
+ foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
+ if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
+ def MMA_NAME<layout_a, layout_b, satf, b1op, op[0], op[1], op[2], op[3]>.record
+ : NVVM_MMA<layout_a, layout_b, satf, b1op, op[0], op[1], op[2], op[3]>;
+ }
+ } // b1op
} // op
} // satf
} // layout_b
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index ab93bf16d4919..4834985b10190 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -146,6 +146,7 @@ def hasPTX63 : Predicate<"Subtarget->getPTXVersion() >= 63">;
def hasPTX64 : Predicate<"Subtarget->getPTXVersion() >= 64">;
def hasPTX65 : Predicate<"Subtarget->getPTXVersion() >= 65">;
def hasPTX70 : Predicate<"Subtarget->getPTXVersion() >= 70">;
+def hasPTX71 : Predicate<"Subtarget->getPTXVersion() >= 71">;
def hasSM30 : Predicate<"Subtarget->getSmVersion() >= 30">;
def hasSM70 : Predicate<"Subtarget->getSmVersion() >= 70">;
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 798538410b104..de4bf2ef3055f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -7796,15 +7796,24 @@ defset list<WMMA_INSTR> MMA_LDSTs = {
} // layout
} // defset
+// B1 instruction variants need extra constraints.
+class MMA_OP_PREDICATES<WMMA_REGINFO FragA, string b1op> {
+ string Op = b1op;
+ WMMA_REGINFO Frag = FragA;
+ list<Predicate> ret = !listconcat(
+ FragA.Predicates,
+ !if(!eq(b1op, ".and.popc"), [hasSM80,hasPTX71],[])
+ );
+}
// WMMA.MMA
class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
WMMA_REGINFO FragC, WMMA_REGINFO FragD,
- string ALayout, string BLayout, int Satfinite, string rnd>
- : WMMA_INSTR<WMMA_NAME<ALayout, BLayout, Satfinite, rnd, FragA, FragB, FragC, FragD>.record,
+ string ALayout, string BLayout, int Satfinite, string rnd, string b1op>
+ : WMMA_INSTR<WMMA_NAME<ALayout, BLayout, Satfinite, rnd, b1op, FragA, FragB, FragC, FragD>.record,
[FragA.Ins, FragB.Ins, FragC.Ins]>,
// Requires does not seem to have effect on Instruction w/o Patterns.
// We set it here anyways and propagate to the Pat<> we construct below.
- Requires<FragA.Predicates> {
+ Requires<MMA_OP_PREDICATES<FragA, b1op>.ret> {
let OutOperandList = FragD.Outs;
let InOperandList = !con(Args, (ins MmaCode:$ptx));
string TypeList = !cond(
@@ -7816,7 +7825,7 @@ class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
# "." # FragC.ptx_elt_type,
);
let AsmString = "wmma.mma"
- # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "")
+ # b1op
# ".sync"
# "${ptx:aligned}"
# "." # ALayout
@@ -7837,13 +7846,15 @@ defset list<WMMA_INSTR> WMMAs = {
foreach satf = [0, 1] in {
foreach rnd = ["", "rn", "rz", "rm", "rp"] in {
foreach op = NVVM_MMA_OPS.all_wmma_ops in {
- if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, satf, rnd>.ret then {
- def : WMMA_MMA<WMMA_REGINFO<op[0], "wmma.mma">,
- WMMA_REGINFO<op[1], "wmma.mma">,
- WMMA_REGINFO<op[2], "wmma.mma">,
- WMMA_REGINFO<op[3], "wmma.mma">,
- layout_a, layout_b, satf, rnd>;
- }
+ foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
+ if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, satf, rnd>.ret then {
+ def : WMMA_MMA<WMMA_REGINFO<op[0], "wmma.mma">,
+ WMMA_REGINFO<op[1], "wmma.mma">,
+ WMMA_REGINFO<op[2], "wmma.mma">,
+ WMMA_REGINFO<op[3], "wmma.mma">,
+ layout_a, layout_b, satf, rnd, b1op>;
+ }
+ } // b1op
} // op
} // rnd
} // satf
@@ -7854,12 +7865,12 @@ defset list<WMMA_INSTR> WMMAs = {
// MMA
class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
WMMA_REGINFO FragC, WMMA_REGINFO FragD,
- string ALayout, string BLayout, int Satfinite>
- : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, FragA, FragB, FragC, FragD>.record,
+ string ALayout, string BLayout, int Satfinite, string b1op>
+ : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, FragA, FragB, FragC, FragD>.record,
[FragA.Ins, FragB.Ins, FragC.Ins]>,
// Requires does not seem to have effect on Instruction w/o Patterns.
// We set it here anyways and propagate to the Pat<> we construct below.
- Requires<FragA.Predicates> {
+ Requires<MMA_OP_PREDICATES<FragA, b1op>.ret> {
let OutOperandList = FragD.Outs;
let InOperandList = !con(Args, (ins MmaCode:$ptx));
string TypeList = "." # FragD.ptx_elt_type
@@ -7872,7 +7883,7 @@ class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
# "." # BLayout
# !if(Satfinite, ".satfinite", "")
# TypeList
- # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "") # "\n\t\t"
+ # b1op # "\n\t\t"
# FragD.regstring # ",\n\t\t"
# FragA.regstring # ",\n\t\t"
# FragB.regstring # ",\n\t\t"
@@ -7884,13 +7895,15 @@ defset list<WMMA_INSTR> MMAs = {
foreach layout_b = ["row", "col"] in {
foreach satf = [0, 1] in {
foreach op = NVVM_MMA_OPS.all_mma_ops in {
- if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
- def : MMA<WMMA_REGINFO<op[0], "mma">,
- WMMA_REGINFO<op[1], "mma">,
- WMMA_REGINFO<op[2], "mma">,
- WMMA_REGINFO<op[3], "mma">,
- layout_a, layout_b, satf>;
- }
+ foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
+ if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
+ def : MMA<WMMA_REGINFO<op[0], "mma">,
+ WMMA_REGINFO<op[1], "mma">,
+ WMMA_REGINFO<op[2], "mma">,
+ WMMA_REGINFO<op[3], "mma">,
+ layout_a, layout_b, satf, b1op>;
+ }
+ } // b1op
} // op
} // satf
} // layout_b
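Putting the AsmString pieces above together, the b1op token lands in a different position for each instruction flavor; my derivation of the emitted mnemonics (operands elided):

  ; wmma flavor: wmma.mma.and.popc.sync.aligned.row.col.m8n8k128.s32.b1.b1.s32
  ; mma flavor:  mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.and.popc
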
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 2daffd0e2cf6c..785e48ce75a24 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -55,14 +55,14 @@
# RUN: llc < %t-ptx65-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx65 \
# RUN: | FileCheck %t-ptx65-sm_75.ll
-# Check all variants of instructions supported by PTX70 on SM80+
-# RUN: %python %s --ptx=70 --gpu-arch=80 > %t-ptx70-sm_80.ll
-# RUN: FileCheck %t-ptx70-sm_80.ll < %t-ptx70-sm_80.ll \
-# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,ALTFLOAT,DOUBLE,PTX65MMA,PTX70MMA
-# RUN: FileCheck %t-ptx70-sm_80.ll < %t-ptx70-sm_80.ll \
+# Check all variants of instructions supported by PTX71 on SM80+
+# RUN: %python %s --ptx=71 --gpu-arch=80 > %t-ptx71-sm_80.ll
+# RUN: FileCheck %t-ptx71-sm_80.ll < %t-ptx71-sm_80.ll \
+# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,ALTFLOAT,DOUBLE,PTX65MMA,PTX71MMA
+# RUN: FileCheck %t-ptx71-sm_80.ll < %t-ptx71-sm_80.ll \
# RUN: --check-prefixes=INTRINSICS
-# RUN: llc < %t-ptx70-sm_80.ll -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 \
-# RUN: | FileCheck %t-ptx70-sm_80.ll
+# RUN: llc < %t-ptx71-sm_80.ll -march=nvptx64 -mcpu=sm_80 -mattr=+ptx71 \
+# RUN: | FileCheck %t-ptx71-sm_80.ll
from __future__ import print_function
@@ -649,9 +649,16 @@ def common_mma_test_gen(params, op, intrinsic_template, instruction_template):
print(Template(mma_template).substitute(test_params))
return (test_params["intrinsic"], test_params["instruction"])
+def get_b1_ops(ptx_type):
+ if ptx_type != "b1":
+ return [""]
+ if ptx_version >= 71:
+ return [".xor.popc", ".and.popc"]
+ return [".xor.popc"]
+
def gen_wmma_mma_tests():
- wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}${rnd}.${intrinsic_signature}${satf}"
- wmma_instruction_template = "wmma.mma${mma_variant}.sync${aligned}.${alayout}.${blayout}.${geom}${rnd}.${ptx_signature}${satf}"
+ wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma${b1op}.${alayout}.${blayout}${rnd}.${intrinsic_signature}${satf}"
+ wmma_instruction_template = "wmma.mma${b1op}.sync${aligned}.${alayout}.${blayout}.${geom}${rnd}.${ptx_signature}${satf}"
generated_items=[]
@@ -665,29 +672,30 @@ def gen_wmma_mma_tests():
if not is_wmma_variant_supported(op, alayout, blayout, rnd, satf):
continue
- params = {
- "aligned" : ".aligned" if ptx_version >= 63 else "",
- "alayout" : alayout,
- "blayout" : blayout,
- "intrinsic_signature" : wmma_signature(op),
- "ptx_signature" : wmma_ptx_signature(op),
- "satf" : satf,
- "rnd" : rnd,
- "geom" : op.a.geom,
- "mma_variant" : ".xor.popc" if op.a.mma_type.ptx_type == "b1" else "",
- }
-
- intrinsic_template = wmma_intrinsic_template
- instruction_template = wmma_instruction_template
-
- generated_items.append(common_mma_test_gen(params, op,
- intrinsic_template, instruction_template))
+ for b1op in get_b1_ops(op.a.mma_type.ptx_type):
+ params = {
+ "aligned" : ".aligned" if ptx_version >= 63 else "",
+ "alayout" : alayout,
+ "blayout" : blayout,
+ "intrinsic_signature" : wmma_signature(op),
+ "ptx_signature" : wmma_ptx_signature(op),
+ "satf" : satf,
+ "rnd" : rnd,
+ "geom" : op.a.geom,
+ "b1op" : b1op
+ }
+
+ intrinsic_template = wmma_intrinsic_template
+ instruction_template = wmma_instruction_template
+
+ generated_items.append(common_mma_test_gen(params, op,
+ intrinsic_template, instruction_template))
return generated_items
def gen_mma_tests():
- mma_intrinsic_template = "llvm.nvvm.mma.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}"
- mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${mma_variant}"
+ mma_intrinsic_template = "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}"
+ mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${b1op}"
generated_items=[]
@@ -700,22 +708,23 @@ def gen_mma_tests():
if not is_mma_variant_supported(op, alayout, blayout, satf):
continue
- params = {
- "aligned" : ".aligned" if ptx_version >= 63 else "",
- "alayout" : alayout,
- "blayout" : blayout,
- "intrinsic_signature" : mma_signature(op),
- "ptx_signature" : mma_ptx_signature(op),
- "satf" : satf,
- "geom" : op.a.geom,
- "mma_variant" : ".xor.popc" if op.a.mma_type.ptx_type == "b1" else "",
- }
+ for b1op in get_b1_ops(op.a.mma_type.ptx_type):
+ params = {
+ "aligned" : ".aligned" if ptx_version >= 63 else "",
+ "alayout" : alayout,
+ "blayout" : blayout,
+ "intrinsic_signature" : mma_signature(op),
+ "ptx_signature" : mma_ptx_signature(op),
+ "satf" : satf,
+ "geom" : op.a.geom,
+ "b1op" : b1op
+ }
- intrinsic_template = mma_intrinsic_template
- instruction_template = mma_instruction_template
+ intrinsic_template = mma_intrinsic_template
+ instruction_template = mma_instruction_template
- generated_items.append(common_mma_test_gen(params, op,
- intrinsic_template, instruction_template))
+ generated_items.append(common_mma_test_gen(params, op,
+ intrinsic_template, instruction_template))
return generated_items
@@ -810,32 +819,35 @@ def gen_check_unsupported_ops(items):
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.u4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.s4
-; PTX70MMA-DAG: mma.m8n8k4.row.col.f64
-; PTX70MMA-DAG: mma.m16n8k4.row.col.tf32
-; PTX70MMA-DAG: mma.m16n8k8.row.col.tf32
-; PTX70MMA-DAG: mma.m16n8k16.row.col.bf16
-; PTX70MMA-DAG: mma.m16n8k8.row.col.bf16
-; PTX70MMA-DAG: mma.m16n8k16.row.col.f16.f16
-; PTX70MMA-DAG: mma.m16n8k16.row.col.f32.f32
-; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.u8
-; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.s8
-; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.u8
-; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.s8
-; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.u8
-; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.s8
-; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.u8
-; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.s8
-; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.u4
-; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.s4
-; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.u4
-; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.s4
-; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.u4
-; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.s4
-; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.u4
-; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.s4
-; PTX70MMA-DAG: mma.m8n8k128.row.col.b1
-; PTX70MMA-DAG: mma.m16n8k128.row.col.b1
-; PTX70MMA-DAG: mma.m16n8k256.row.col.b1
+; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
+; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
+; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
+; PTX71MMA-DAG: mma.m16n8k16.row.col.bf16
+; PTX71MMA-DAG: mma.m16n8k8.row.col.bf16
+; PTX71MMA-DAG: mma.m16n8k16.row.col.f16.f16
+; PTX71MMA-DAG: mma.m16n8k16.row.col.f32.f32
+; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.u8
+; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.s8
+; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.u8
+; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.s8
+; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.u8
+; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.s8
+; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.u8
+; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.s8
+; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.u4
+; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.s4
+; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.u4
+; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.s4
+; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.u4
+; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.s4
+; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.u4
+; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.s4
+; PTX71MMA-DAG: mma.and.popc.m8n8k128.row.col.b1
+; PTX71MMA-DAG: mma.xor.popc.m8n8k128.row.col.b1
+; PTX71MMA-DAG: mma.and.popc.m16n8k128.row.col.b1
+; PTX71MMA-DAG: mma.xor.popc.m16n8k128.row.col.b1
+; PTX71MMA-DAG: mma.and.popc.m16n8k256.row.col.b1
+; PTX71MMA-DAG: mma.xor.popc.m16n8k256.row.col.b1
;
""")