[llvm] [LLVM][NVPTX] Add support for ldmatrix extensions introduced in PTX 8.6 (PR #124899)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jan 29 00:50:22 PST 2025
github-actions[bot] wrote:
Warning: the Python code formatter, darker, found issues in your code.
You can test this locally with the following command:
``````````bash
darker --check --diff -r 9326633abd0e59fc77072488ee8cded4fe83c8a1...303781cdf726e0bbfa5023d021684b441d13a270 llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py llvm/test/CodeGen/NVPTX/wmma.py
``````````
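If the report looks right, darker can apply the fixes for you instead of just diffing them. A minimal sketch, assuming darker is installed (e.g. via `pip install darker`) and the command is run from the llvm-project checkout root; dropping `--check --diff` makes darker rewrite the listed files in place:

``````````bash
# Sketch only: assumes darker is installed and this is run from the repo root.
# Without --check/--diff, darker reformats in place, touching only the lines
# that changed in the given revision range.
darker -r 9326633abd0e59fc77072488ee8cded4fe83c8a1...303781cdf726e0bbfa5023d021684b441d13a270 \
    llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py \
    llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py \
    llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py \
    llvm/test/CodeGen/NVPTX/wmma.py
``````````

Re-running the original `--check --diff` command afterwards should then report no remaining differences.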
View the diff from darker here.
``````````diff
--- wmma.py 2025-01-29 08:29:39.000000 +0000
+++ wmma.py 2025-01-29 08:49:49.099649 +0000
@@ -18,12 +18,12 @@
"f32": "float",
"f64": "double",
"s32": "i32",
"b16": "i32",
"b8": "i32",
- "b8x16.b6x16_p32" : "i32",
- "b8x16.b4x16_p64" : "i32",
+ "b8x16.b6x16_p32": "i32",
+ "b8x16.b4x16_p64": "i32",
"s8": "i32",
"u8": "i32",
"s4": "i32",
"u4": "i32",
"b1": "i32",
@@ -168,16 +168,16 @@
"m16n16:x2:b8": 4,
"m16n16:x1:b8x16.b6x16_p32": 2,
"m16n16:x2:b8x16.b6x16_p32": 4,
"m16n16:x1:b8x16.b4x16_p64": 2,
"m16n16:x2:b8x16.b4x16_p64": 4,
- "m8n16:x1:b8x16.b6x16_p32" : 1,
- "m8n16:x2:b8x16.b6x16_p32" : 2,
- "m8n16:x4:b8x16.b6x16_p32" : 4,
- "m8n16:x1:b8x16.b4x16_p64" : 1,
- "m8n16:x2:b8x16.b4x16_p64" : 2,
- "m8n16:x4:b8x16.b4x16_p64" : 4,
+ "m8n16:x1:b8x16.b6x16_p32": 1,
+ "m8n16:x2:b8x16.b6x16_p32": 2,
+ "m8n16:x4:b8x16.b6x16_p32": 4,
+ "m8n16:x1:b8x16.b4x16_p64": 1,
+ "m8n16:x2:b8x16.b4x16_p64": 2,
+ "m8n16:x4:b8x16.b4x16_p64": 4,
}.get(
"%s:%s:%s" % (geom, frag, ptx_elt_type),
{
# All other FP shape/fragment/type combinations have the same size
"a:f16": 8,
@@ -302,13 +302,19 @@
)
return [x for x in ldst_ops if (x.frag == "d") == (kind == "store")]
def get_ldmatrix_ops():
- return (make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])
- + make_ldmatrix_ops(["m16n16"], ["x1", "x2"], ["b8", "b8x16.b6x16_p32", "b8x16.b4x16_p64"])
- + make_ldmatrix_ops(["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]))
+ return (
+ make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])
+ + make_ldmatrix_ops(
+ ["m16n16"], ["x1", "x2"], ["b8", "b8x16.b6x16_p32", "b8x16.b4x16_p64"]
+ )
+ + make_ldmatrix_ops(
+ ["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]
+ )
+ )
def is_wmma_geom_supported(geom):
# geometries for FP and ints.
if geom in ["m8n32k16", "m32n8k16"]:
@@ -350,10 +356,11 @@
elif geom in ["m16n16"]:
return ptx_version >= 86 and gpu_arch >= 100 and aa
elif geom in ["m8n16"]:
return ptx_version >= 86 and gpu_arch >= 100 and aa
assert False # Unexpected geometry.
+
def is_ldmatrix_trans_supported(geom, trans):
if geom in ["m8n8"]:
return True
elif geom in ["m16n16"]:
@@ -1042,11 +1049,11 @@
global gpu_arch
global aa
parser = argparse.ArgumentParser()
parser.add_argument("--ptx", type=int, default=60)
parser.add_argument("--gpu-arch", type=int, default=70)
- parser.add_argument("--aa", action='store_true')
+ parser.add_argument("--aa", action="store_true")
args = parser.parse_args()
ptx_version = args.ptx
gpu_arch = args.gpu_arch
aa = args.aa
``````````
https://github.com/llvm/llvm-project/pull/124899